def _scoo_cupy(S: sax.SType) -> sax.SCoo:
    """Convert an S-parameter to SCoo with CuPy arrays for values.

    Index arrays are kept on the host as ``int32`` NumPy arrays; only the
    values are placed on the GPU.

    Args:
        S: The S-matrix to convert. A ``dict`` keyed by ``(port1, port2)``
            tuples is converted directly; anything else is delegated to
            ``sax.scoo``.

    Returns:
        Tuple ``(Si, Sj, Sx, ports_map)``. For dictionary input, ``Si`` holds
        the index of the first port of every key, ``Sj`` the index of the
        second port, and ``Sx`` the values stacked along the last axis.
    """
    if isinstance(S, dict):
        # Ports are numbered in first-seen order (dict.fromkeys keeps order).
        all_ports = dict.fromkeys(port for pair in S for port in pair)
        ports_map = {port: idx for idx, port in enumerate(all_ports)}
        # Row index <- first port of the key, column index <- second port.
        # NOTE(review): this is assumed to be the convention ``sax.scoo``
        # uses in the fallback branch below; the previous code had the two
        # swapped, which transposes non-reciprocal S-matrices.
        Si = np.array([ports_map[p1] for p1, _ in S], dtype=np.int32)
        Sj = np.array([ports_map[p2] for _, p2 in S], dtype=np.int32)
        # Entries of one dict may mix scalars and arrays; bring them to a
        # common shape before stacking.
        values = cp.broadcast_arrays(*(cp.asarray(v) for v in S.values()))
        Sx = cp.stack(values, -1)
        return Si, Sj, Sx, ports_map
    Si, Sj, Sx, ports_map = sax.scoo(S)
    return (
        np.asarray(Si, dtype=np.int32),
        np.asarray(Sj, dtype=np.int32),
        cp.asarray(Sx),
        ports_map,
    )
def _dense_from_coo(
    rows: np.ndarray, cols: np.ndarray, vals: cp.ndarray, n: int
) -> cp.ndarray:
    """Scatter batched COO values into dense ``(batch, n, n)`` matrices.

    Repeated ``(row, col)`` pairs are summed. A plain fancy-index assignment
    would keep an arbitrary one of the colliding values instead.

    Args:
        rows: Row indices of the non-zero values (host array).
        cols: Column indices of the non-zero values (host array).
        vals: Non-zero values, shape (batch, nnz).
        n: Size of the square matrices to build.

    Returns:
        Dense matrices of shape (batch, n, n) with the dtype of ``vals``.
    """
    rows = np.asarray(rows).astype(np.int64, copy=False).ravel()
    cols = np.asarray(cols).astype(np.int64, copy=False).ravel()
    dense = cp.zeros((vals.shape[0], n, n), dtype=vals.dtype)
    if rows.size == 0:
        return dense

    # Rank every entry within its group of identical (row, col) pairs:
    # 0 for the first occurrence, 1 for the second, and so on.
    flat = rows * n + cols
    order = np.argsort(flat, kind="stable")
    sorted_flat = flat[order]
    positions = np.arange(flat.size)
    is_group_start = np.ones(flat.size, dtype=bool)
    is_group_start[1:] = sorted_flat[1:] != sorted_flat[:-1]
    group_start = np.maximum.accumulate(np.where(is_group_start, positions, 0))
    rank = np.empty(flat.size, dtype=np.int64)
    rank[order] = positions - group_start

    if not rank.any():
        # No duplicates: a single scatter is enough.
        dense[:, cp.asarray(rows), cp.asarray(cols)] = vals
        return dense

    # Indices within one rank level are unique, so each in-place update is
    # well defined; the levels are accumulated one after another.
    for level in range(int(rank.max()) + 1):
        sel = np.flatnonzero(rank == level)
        level_rows = cp.asarray(rows[sel])
        level_cols = cp.asarray(cols[sel])
        level_vals = vals[:, cp.asarray(sel)]
        if level == 0:
            dense[:, level_rows, level_cols] = level_vals
        else:
            dense[:, level_rows, level_cols] += level_vals
    return dense


def _solve_cuda(
    Ai: np.ndarray, Aj: np.ndarray, Ax: cp.ndarray, B: cp.ndarray
) -> cp.ndarray:
    """Batched sparse solve using dense GPU operations.

    Builds dense matrices from the fixed sparsity pattern (Ai, Aj) with
    per-batch values (Ax), then solves all systems with one call to
    ``cp.linalg.solve``. Duplicate (row, column) entries are summed.

    Args:
        Ai: Row indices of non-zero values (topology, shared across batch).
        Aj: Column indices of non-zero values (topology, shared across batch).
        Ax: Non-zero values, shape (batch, nnz).
        B: Right-hand side matrix, shape (n, n_rhs).

    Returns:
        Solution matrix, shape (batch, n, n_rhs).
    """
    n = int(B.shape[0])
    batch = Ax.shape[0]
    A_dense = _dense_from_coo(Ai, Aj, Ax, n)
    # The same right-hand side is used for every system in the batch.
    return cp.linalg.solve(A_dense, cp.broadcast_to(B, (batch, *B.shape)))


def _coo_mul_vec(
    Si: np.ndarray, Sj: np.ndarray, Sx: cp.ndarray, x: cp.ndarray
) -> cp.ndarray:
    """Batched sparse matrix-dense matrix multiply using dense GPU matmul.

    Builds dense matrices from the fixed sparsity pattern (Si, Sj) with
    per-batch values (Sx), then multiplies them with ``x``. Duplicate
    (row, column) entries are summed.

    Args:
        Si: Row indices of non-zero values (topology, shared across batch).
        Sj: Column indices of non-zero values (topology, shared across batch).
        Sx: Non-zero values, shape (batch, nnz).
        x: Dense matrix to multiply, shape (batch, n, m).

    Returns:
        Result of S @ x, shape (batch, n, m).
    """
    n = int(x.shape[-2])
    return _dense_from_coo(Si, Sj, Sx, n) @ x
def analyze_instances_cuda(
    instances: dict[sax.InstanceName, sax.Instance],
    models: dict[str, sax.Model],
) -> dict[str, sax.SCoo]:
    """Evaluate every used model once and hand its SCoo to each instance.

    Each distinct component is called with its default arguments and
    converted with ``_scoo_cupy``; instances that share a component share
    the resulting tuple.

    Args:
        instances: Dictionary mapping instance names to instance definitions.
        models: Dictionary mapping component names to their model functions.

    Returns:
        Dictionary mapping instance names to their S-matrices in SCoo format.
    """
    instances = sax.into[sax.Instances](instances)
    components = {inst["component"] for inst in instances.values()}
    scoo_per_component = {
        component: _scoo_cupy(models[component]()) for component in components
    }
    return {
        name: scoo_per_component[inst["component"]]
        for name, inst in instances.items()
    }
def analyze_circuit_cuda(
    analyzed_instances: dict[sax.InstanceName, sax.SCoo],
    nets: sax.Nets,
    ports: sax.Ports,
) -> Any:  # noqa: ANN401
    """Pre-compute the topology arrays consumed by evaluate_circuit_cuda.

    All instance matrices are laid out block-diagonally in one global port
    numbering. From that numbering this builds the sparsity pattern of the
    connection-times-S product plus identity, and the dense selector matrix
    for the external ports.

    Args:
        analyzed_instances: Instance S-matrices from analyze_instances_cuda.
        nets: List of net dictionaries with "p1" and "p2" keys.
        ports: Dictionary mapping external port names to instance ports.

    Returns:
        Tuple of pre-computed arrays for evaluate_circuit_cuda.
    """
    port_map = {ext_name: pos for pos, ext_name in enumerate(ports)}
    inverse_ports = {inst_port: ext_name for ext_name, inst_port in ports.items()}

    # Global numbering: every instance gets a contiguous range of indices.
    offset = 0
    row_chunks, col_chunks = [], []
    instance_ports = {}
    for inst_name, inst in analyzed_instances.items():
        si, sj, _, ports_map = sax.scoo(inst)
        row_chunks.append(np.asarray(si) + offset)
        col_chunks.append(np.asarray(sj) + offset)
        for port, local_idx in ports_map.items():
            instance_ports[f"{inst_name},{port}"] = local_idx + offset
        offset += len(ports_map)

    n_col = offset
    n_rhs = len(port_map)
    Si = np.concatenate(row_chunks, -1)
    Sj = np.concatenate(col_chunks, -1)

    # Symmetric connection pattern: each net contributes both directions.
    links: set[tuple[int, int]] = set()
    for net in nets:
        a = int(instance_ports[net["p1"]])
        b = int(instance_ports[net["p2"]])
        links.add((a, b))
        links.add((b, a))
    ordered_links = sorted(links)
    Ci = np.array([a for a, _ in ordered_links], dtype=np.int32)
    Cj = np.array([b for _, b in ordered_links], dtype=np.int32)

    # Selector with a one at (global port index, external port position).
    Cextmap = {
        int(instance_ports[inst_port]): int(port_map[ext_name])
        for inst_port, ext_name in inverse_ports.items()
    }
    Cexti = cp.asarray(list(Cextmap))
    Cextj = cp.asarray(list(Cextmap.values()))
    Cext = cp.zeros((n_col, n_rhs), dtype=complex)
    Cext[Cexti, Cextj] = 1.0

    # Every S entry whose row equals a connection's column yields one entry
    # of the product; np.nonzero walks the matches in row-major order.
    cs_s_indices, link_idx = np.nonzero(Si[:, None] == Cj[None, :])
    CSi = Ci[link_idx]
    CSj = Sj[cs_s_indices]

    diag = np.arange(n_col)
    I_CSi = np.concatenate([CSi, diag], -1)
    I_CSj = np.concatenate([CSj, diag], -1)

    instance_order = tuple(
        (inst_name, inst[1]) for inst_name, inst in analyzed_instances.items()
    )
    return (
        n_col,
        cs_s_indices,
        Si,
        Sj,
        Cext,
        Cexti,
        Cextj,
        I_CSi,
        I_CSj,
        instance_order,
        tuple(port_map),
    )
def evaluate_circuit_cuda(
    analyzed: Any,  # noqa: ANN401
    instances: dict[sax.InstanceName, sax.SType],
) -> sax.SDense:
    """Evaluate circuit S-matrix using batched dense GPU operations.

    The instance values are gathered in the instance order stored in
    ``analyzed``, broadcast to a common batch shape and flattened to a single
    batch axis. The system is then solved with ``_solve_cuda`` and multiplied
    with ``_coo_mul_vec``.

    Args:
        analyzed: Pre-computed analysis from analyze_circuit_cuda.
        instances: Dictionary mapping instance names to evaluated S-matrices.

    Returns:
        Circuit S-matrix in SDense format.

    Raises:
        ValueError: If the batch shapes of the instance values cannot be
            broadcast against each other.
    """
    (
        n_col,
        cs_s_indices,
        Si,
        Sj,
        Cext,
        Cexti,
        Cextj,
        I_CSi,
        I_CSj,
        dummy_pms,
        port_map,
    ) = analyzed

    # Only the values are needed here; the topology comes from `analyzed`.
    values = [_scoo_cupy(instances[name])[2] for name, _ in dummy_pms]

    # Use the true broadcast of all batch shapes. Taking the first
    # highest-rank shape fails for e.g. (1,) followed by (1000,).
    batch_shape = np.broadcast_shapes(*(sx.shape[:-1] for sx in values))

    Sx = cp.concatenate(
        [cp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in values], -1
    )
    CSx = Sx[..., cs_s_indices]
    Ix = cp.ones((*batch_shape, n_col))
    I_CSx = cp.concatenate([-CSx, Ix], -1)

    Sx = Sx.reshape(-1, Sx.shape[-1])  # n_lhs x N
    I_CSx = I_CSx.reshape(-1, I_CSx.shape[-1])  # n_lhs x M
    inv_I_CS_Cext = _solve_cuda(I_CSi, I_CSj, I_CSx, Cext)
    S_inv_I_CS_Cext = _coo_mul_vec(Si, Sj, Sx, inv_I_CS_Cext)

    # Keep only the rows/columns that belong to external ports.
    CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj]

    _, n, _ = CextT_S_inv_I_CS_Cext.shape
    S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n)

    return jnp.asarray(S), {p: i for i, p in enumerate(port_map)}
@pytest.mark.parametrize("backend", ["cuda", "default", "klu", "filipsson_gunnar"])
def test_backend(backend: str) -> None:
    """Check that a backend reproduces the KLU result on a small circuit."""
    if backend not in sax.backends.circuit_backends:
        pytest.skip(f"{backend} backend not available")

    thru = 0.7071067811865476
    cross = 0.7071067811865476j
    wg_transmission = -0.99477 - 0.10211j

    def coupler_model() -> tuple:
        # Dense 6x6 matrix plus a port map that names only three of the rows.
        matrix = jnp.array(
            [
                [
                    5.19688622e-06 - 1.19777138e-05j,
                    6.30595625e-16 - 1.48061189e-17j,
                    -3.38542541e-01 - 6.15711852e-01j,
                    5.80662654e-03 - 1.11068866e-02j,
                    -3.38542542e-01 - 6.15711852e-01j,
                    -5.80662660e-03 + 1.11068866e-02j,
                ],
                [
                    8.59445189e-16 - 8.29783014e-16j,
                    -2.08640825e-06 + 8.17315497e-06j,
                    2.03847666e-03 - 2.10649131e-03j,
                    5.30509661e-01 + 4.62504708e-01j,
                    -2.03847666e-03 + 2.10649129e-03j,
                    5.30509662e-01 + 4.62504708e-01j,
                ],
                [
                    -3.38542541e-01 - 6.15711852e-01j,
                    2.03847660e-03 - 2.10649129e-03j,
                    7.60088070e-06 + 9.07340423e-07j,
                    2.79292426e-09 + 2.79093547e-07j,
                    5.07842364e-06 + 2.16385350e-06j,
                    -6.84244232e-08 - 5.00486817e-07j,
                ],
                [
                    5.80662707e-03 - 1.11068869e-02j,
                    5.30509661e-01 + 4.62504708e-01j,
                    2.79291895e-09 + 2.79093540e-07j,
                    -4.55645798e-06 + 1.50570403e-06j,
                    6.84244128e-08 + 5.00486817e-07j,
                    -3.55812153e-06 + 4.59781091e-07j,
                ],
                [
                    -3.38542541e-01 - 6.15711852e-01j,
                    -2.03847672e-03 + 2.10649131e-03j,
                    5.07842364e-06 + 2.16385349e-06j,
                    6.84244230e-08 + 5.00486816e-07j,
                    7.60088070e-06 + 9.07340425e-07j,
                    -2.79292467e-09 - 2.79093547e-07j,
                ],
                [
                    -5.80662607e-03 + 1.11068863e-02j,
                    5.30509662e-01 + 4.62504708e-01j,
                    -6.84244296e-08 - 5.00486825e-07j,
                    -3.55812153e-06 + 4.59781093e-07j,
                    -2.79293217e-09 - 2.79093547e-07j,
                    -4.55645798e-06 + 1.50570403e-06j,
                ],
            ]
        )
        return matrix, {"in0": 0, "out0": 2, "out1": 4}

    models = {
        "wg": lambda: {
            ("in0", "out0"): wg_transmission,
            ("out0", "in0"): wg_transmission,
        },
        "mmi": lambda: {
            ("in0", "out0"): thru,
            ("in0", "out1"): cross,
            ("in1", "out0"): cross,
            ("in1", "out1"): thru,
            ("out0", "in0"): thru,
            ("out1", "in0"): cross,
            ("out0", "in1"): cross,
            ("out1", "in1"): thru,
        },
        "coupler": coupler_model,
    }
    instances = {
        "lft": {"component": "coupler"},
        "top": {"component": "wg"},
        "rgt": {"component": "mmi"},
    }
    nets = [
        {"p1": "lft,out0", "p2": "rgt,in0"},
        {"p1": "lft,out1", "p2": "top,in0"},
        {"p1": "top,out0", "p2": "rgt,in1"},
    ]
    ports = {"in0": "lft,in0", "out0": "rgt,out0"}

    def simulate(backend_functions: tuple) -> dict:
        """Run the analyze/analyze/evaluate pipeline and return an SDict."""
        analyze_instances, analyze_circuit, evaluate_circuit = backend_functions
        analyzed_instances = analyze_instances(instances, models)
        analyzed_circuit = analyze_circuit(analyzed_instances, nets, ports)
        evaluated = {
            name: models[inst["component"]]() for name, inst in instances.items()
        }
        return sax.sdict(evaluate_circuit(analyzed_circuit, evaluated))

    sdict_backend = simulate(sax.backends.circuit_backends[backend])
    sdict_klu = simulate(
        (
            sax.backends.analyze_instances_klu,
            sax.backends.analyze_circuit_klu,
            sax.backends.evaluate_circuit_klu,
        )
    )

    # Compare to klu backend as source of truth
    for key, val_klu in sdict_klu.items():
        val_backend = sdict_backend[key]
        assert abs(val_klu - val_backend) < 1e-5
def test_quick_start() -> None:
    """Runs the core parts of the quick start notebook.

    This does not use jax.jit, to support the cupy backend.
    """

    def coupler(coupling: float = 0.5) -> sax.SDict:
        # Only the forward entries are listed; sax.reciprocal is applied on top.
        tau = (1 - coupling) ** 0.5
        kappa = coupling**0.5
        forward = {
            ("in0", "out0"): tau,
            ("in0", "out1"): 1j * kappa,
            ("in1", "out0"): 1j * kappa,
            ("in1", "out1"): tau,
        }
        return sax.reciprocal(forward)

    coupler(coupling=0.3)

    def waveguide(
        wl: float = 1.55,
        wl0: float = 1.55,
        neff: float = 2.34,
        ng: float = 3.4,
        length: float = 10.0,
        loss: float = 0.0,
    ) -> sax.SDict:
        # First-order dispersion of the effective index around wl0.
        slope = (ng - neff) / wl0
        neff_at_wl = neff - (wl - wl0) * slope
        phase = 2 * jnp.pi * neff_at_wl * length / wl
        amplitude = 10 ** (-loss * length / 20)
        return sax.reciprocal({("in0", "out0"): amplitude * jnp.exp(1j * phase)})

    netlist = {
        "instances": {
            "lft": "coupler",
            "top": "waveguide",
            "btm": "waveguide",
            "rgt": "coupler",
        },
        "connections": {
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    }
    models = {"coupler": coupler, "waveguide": waveguide}
    mzi, _info = sax.circuit(netlist=netlist, models=models)

    arm_lengths = {"top": {"length": 25.0}, "btm": {"length": 15.0}}
    mzi()
    mzi(**arm_lengths)
    wl = jnp.linspace(1.51, 1.59, 1000)
    mzi(wl=wl, **arm_lengths)