From 80cf2cbe719e4eefcfc8e3f85643b94d463a777d Mon Sep 17 00:00:00 2001 From: Joaquin Matres <4514346+joamatab@users.noreply.github.com> Date: Wed, 22 May 2024 14:01:34 -0400 Subject: [PATCH 1/9] badd cuda --- sax/backends/__init__.py | 24 +++++ sax/backends/cuda.py | 218 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 sax/backends/cuda.py diff --git a/sax/backends/__init__.py b/sax/backends/__init__.py index d3ba1879..97a42bc4 100644 --- a/sax/backends/__init__.py +++ b/sax/backends/__init__.py @@ -60,6 +60,30 @@ "better performance during circuit evaluation!" ) +try: + from .cuda import analyze_circuit_klu, analyze_instances_klu, evaluate_circuit_klu + + circuit_backends["klu_cuda"] = ( + analyze_instances_klu, + analyze_circuit_klu, + evaluate_circuit_klu, + ) + circuit_backends["default"] = ( + analyze_instances_klu, + analyze_circuit_klu, + evaluate_circuit_klu, + ) +except ImportError: + circuit_backends["default"] = ( + analyze_instances_fg, + analyze_circuit_fg, + evaluate_circuit_fg, + ) + warnings.warn( + "cuda not found. Please install klujax, cupy and cupyx for " + "better performance during circuit evaluation!" + ) + def analyze_instances( instances: Dict[str, Component], diff --git a/sax/backends/cuda.py b/sax/backends/cuda.py new file mode 100644 index 00000000..aaa5edc5 --- /dev/null +++ b/sax/backends/cuda.py @@ -0,0 +1,218 @@ +""" SAX KLU Backend """ + +from __future__ import annotations + +from typing import Any, Dict + +import cupy as cp +import cupyx +import klujax # Assuming klujax is compatible with cupy or provide a similar interface +from natsort import natsorted + +from ..netlist import Component +from ..saxtypes import Model, SCoo, SDense, SType, scoo + + +def solve_klu(Ai, Aj, Ax, B): + """ + Custom solver using CuPy for sparse matrix solve. + + Args: + Ai (array): Row indices of non-zero values in the sparse matrix. + Aj (array): Column indices of non-zero values in the sparse matrix. + Ax (array): Non-zero values of the sparse matrix. + B (array): Right-hand side matrix to solve for. + + Returns: array: Solution matrix. + """ + # Create sparse matrix in COO format + A_coo = cupyx.scipy.sparse.coo_matrix((Ax, (Ai, Aj))) + + # Convert to CSR format for solving + A_csr = A_coo.tocsr() + + # Solve the linear system + solution = cupyx.scipy.sparse.linalg.spsolve(A_csr, B) + + return solution + + +def coo_mul_vec(Si, Sj, Sx, x): + """ + COO matrix-vector multiplication using CuPy. + + Args: + Si (array): Row indices of non-zero values in the sparse matrix. + Sj (array): Column indices of non-zero values in the sparse matrix. + Sx (array): Non-zero values of the sparse matrix. + x (array): Dense vector to multiply with the sparse matrix. + + Returns: + array: Resulting vector from the multiplication. + """ + # Create sparse matrix in COO format + S_coo = cupyx.scipy.sparse.coo_matrix((Sx, (Si, Sj))) + + # Perform the matrix-vector multiplication + result = S_coo.dot(x) + return result + + +def analyze_instances_klu( + instances: Dict[str, Component], + models: Dict[str, Model], +) -> Dict[str, SCoo]: + instances, instances_old = {}, instances + for k, v in instances_old.items(): + if not isinstance(v, Component): + v = Component(**v) + instances[k] = v + model_names = set() + for i in instances.values(): + if i.info and "model" in i.info and isinstance(i.info["model"], str): + model_names.add(str(i.info["model"])) + else: + model_names.add(str(i.component)) + dummy_models = {k: scoo(models[k]()) for k in model_names} + dummy_instances = {} + for k, i in instances.items(): + if i.info and "model" in i.info and isinstance(i.info["model"], str): + dummy_instances[k] = dummy_models[str(i.info["model"])] + else: + dummy_instances[k] = dummy_models[str(i.component)] + return dummy_instances + + +def analyze_circuit_klu( + analyzed_instances: Dict[str, SCoo], + connections: Dict[str, str], + ports: Dict[str, str], +) -> Any: + connections = {**connections, **{v: k for k, v in connections.items()}} + inverse_ports = {v: k for k, v in ports.items()} + port_map = {k: i for i, k in enumerate(ports)} + + idx, Si, Sj, instance_ports = 0, [], [], {} + for name, instance in analyzed_instances.items(): + si, sj, _, ports_map = scoo(instance) + Si.append(si + idx) + Sj.append(sj + idx) + instance_ports.update({f"{name},{p}": i + idx for p, i in ports_map.items()}) + idx += len(ports_map) + + n_col = idx + n_rhs = len(port_map) + + Si = cp.concatenate(Si, -1) + Sj = cp.concatenate(Sj, -1) + + Cmap = { + int(instance_ports[k]): int(instance_ports[v]) for k, v in connections.items() + } + Ci = cp.array(list(Cmap.keys()), dtype=cp.int32) + Cj = cp.array(list(Cmap.values()), dtype=cp.int32) + + Cextmap = { + int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items() + } + Cexti = cp.stack(list(Cextmap.keys()), 0) + Cextj = cp.stack(list(Cextmap.values()), 0) + Cext = cp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0) + + mask = Cj[None, :] == Si[:, None] + CSi = cp.broadcast_to(Ci[None, :], mask.shape)[mask] + + mask = (Cj[:, None] == Si[None, :]).any(0) + CSj = Sj[mask] + + Ii = Ij = cp.arange(n_col) + I_CSi = cp.concatenate([CSi, Ii], -1) + I_CSj = cp.concatenate([CSj, Ij], -1) + return ( + n_col, + mask, + Si, + Sj, + Cext, + Cexti, + Cextj, + I_CSi, + I_CSj, + tuple((k, v[1]) for k, v in analyzed_instances.items()), + tuple(port_map), + ) + + +def evaluate_circuit_klu(analyzed: Any, instances: Dict[str, SType]) -> SDense: + ( + n_col, + mask, + Si, + Sj, + Cext, + Cexti, + Cextj, + I_CSi, + I_CSj, + dummy_pms, + port_map, + ) = analyzed + + idx = 0 + Sx = [] + batch_shape = () + for name, pm_ in dummy_pms: + _, _, sx, ports_map = scoo(instances[name]) + Sx.append(sx) + if len(sx.shape[:-1]) > len(batch_shape): + batch_shape = sx.shape[:-1] + idx += len(ports_map) + + Sx = cp.concatenate( + [cp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1 + ) + CSx = Sx[..., mask] + Ix = cp.ones((*batch_shape, n_col)) + I_CSx = cp.concatenate([-CSx, Ix], -1) + + Sx = Sx.reshape(-1, Sx.shape[-1]) # n_lhs x N + I_CSx = I_CSx.reshape(-1, I_CSx.shape[-1]) # n_lhs x M + inv_I_CS_Cext = solve_klu(I_CSi, I_CSj, I_CSx, Cext) + S_inv_I_CS_Cext = coo_mul_vec(Si, Sj, Sx, inv_I_CS_Cext) + + CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj] + + _, n, _ = CextT_S_inv_I_CS_Cext.shape + S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n) + + return S, {p: i for i, p in enumerate(port_map)} + + +def _get_instance_ports(connections: Dict[str, str], ports: Dict[str, str]): + instance_ports = {} + for connection in connections.items(): + for ip in connection: + i, p = ip.split(",") + if i not in instance_ports: + instance_ports[i] = set() + instance_ports[i].add(p) + for ip in ports.values(): + i, p = ip.split(",") + if i not in instance_ports: + instance_ports[i] = set() + instance_ports[i].add(p) + return {k: natsorted(v) for k, v in instance_ports.items()} + + +def _get_dummy_instances(connections, ports): + """no longer used. deprecated by analyze_instances_klu.""" + instance_ports = _get_instance_ports(connections, ports) + dummy_instances = {} + for name, ports in instance_ports.items(): + num_ports = len(ports) + pm = {p: i for i, p in enumerate(ports)} + ij = cp.mgrid[:num_ports, :num_ports] + i = ij[0].ravel() + j = ij[1].ravel() + dummy_instances[name] = (i, j, None, pm) + return dummy_instances From 734c32bfa6533e5f7874de24dfdafe19e8c4497b Mon Sep 17 00:00:00 2001 From: Joaquin Matres <4514346+joamatab@users.noreply.github.com> Date: Wed, 22 May 2024 14:26:46 -0400 Subject: [PATCH 2/9] add cuda to optional packages --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b66eb70d..17f9c7af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,11 @@ dev = [ # "meep", ] +cuda = [ + "jax[cuda]", + "cupy", +] + [tool.setuptools.packages.find] where = ["."] include = ["sax", "sax.nn", "sax.backends"] From be4acb9b5902f0de2964a34d89182ecc9bb5c133 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 22 May 2024 13:46:53 -0700 Subject: [PATCH 3/9] Add CUDA backend. --- sax/backends/__init__.py | 22 +++++++----- sax/backends/cuda.py | 78 ++++++++++++++++------------------------ 2 files changed, 44 insertions(+), 56 deletions(-) diff --git a/sax/backends/__init__.py b/sax/backends/__init__.py index 97a42bc4..aae94108 100644 --- a/sax/backends/__init__.py +++ b/sax/backends/__init__.py @@ -61,17 +61,21 @@ ) try: - from .cuda import analyze_circuit_klu, analyze_instances_klu, evaluate_circuit_klu + from .cuda import ( + analyze_circuit_cuda, + analyze_instances_cuda, + evaluate_circuit_cuda, + ) - circuit_backends["klu_cuda"] = ( - analyze_instances_klu, - analyze_circuit_klu, - evaluate_circuit_klu, + circuit_backends["cuda"] = ( + analyze_instances_cuda, + analyze_circuit_cuda, + evaluate_circuit_cuda, ) circuit_backends["default"] = ( - analyze_instances_klu, - analyze_circuit_klu, - evaluate_circuit_klu, + analyze_instances_cuda, + analyze_circuit_cuda, + evaluate_circuit_cuda, ) except ImportError: circuit_backends["default"] = ( @@ -80,7 +84,7 @@ evaluate_circuit_fg, ) warnings.warn( - "cuda not found. Please install klujax, cupy and cupyx for " + "cupy not found. Please install cupy for " "better performance during circuit evaluation!" ) diff --git a/sax/backends/cuda.py b/sax/backends/cuda.py index aaa5edc5..d55abe55 100644 --- a/sax/backends/cuda.py +++ b/sax/backends/cuda.py @@ -1,4 +1,4 @@ -""" SAX KLU Backend """ +""" SAX CUDA Backend """ from __future__ import annotations @@ -6,14 +6,20 @@ import cupy as cp import cupyx -import klujax # Assuming klujax is compatible with cupy or provide a similar interface +import cupyx.scipy.sparse.linalg +import jax.numpy as jnp from natsort import natsorted from ..netlist import Component from ..saxtypes import Model, SCoo, SDense, SType, scoo -def solve_klu(Ai, Aj, Ax, B): +def scoo_cupy(S): + Si, Sj, Sx, ports_map = scoo(S) + return cp.asarray(Si), cp.asarray(Sj), cp.asarray(Sx), ports_map + + +def solve_cuda(Ai, Aj, Ax, B): """ Custom solver using CuPy for sparse matrix solve. @@ -25,6 +31,9 @@ def solve_klu(Ai, Aj, Ax, B): Returns: array: Solution matrix. """ + # TODO: Maybe the shape of Ax is wrong? Unsure -- the KLU backend uses jax.vmap. + Ax = Ax[0] + # Create sparse matrix in COO format A_coo = cupyx.scipy.sparse.coo_matrix((Ax, (Ai, Aj))) @@ -50,6 +59,9 @@ def coo_mul_vec(Si, Sj, Sx, x): Returns: array: Resulting vector from the multiplication. """ + # TODO: Maybe the shape of Ax is wrong? Unsure -- the KLU backend uses jax.vmap. + Sx = Sx[0] + # Create sparse matrix in COO format S_coo = cupyx.scipy.sparse.coo_matrix((Sx, (Si, Sj))) @@ -58,7 +70,7 @@ def coo_mul_vec(Si, Sj, Sx, x): return result -def analyze_instances_klu( +def analyze_instances_cuda( instances: Dict[str, Component], models: Dict[str, Model], ) -> Dict[str, SCoo]: @@ -73,7 +85,7 @@ def analyze_instances_klu( model_names.add(str(i.info["model"])) else: model_names.add(str(i.component)) - dummy_models = {k: scoo(models[k]()) for k in model_names} + dummy_models = {k: scoo_cupy(models[k]()) for k in model_names} dummy_instances = {} for k, i in instances.items(): if i.info and "model" in i.info and isinstance(i.info["model"], str): @@ -83,7 +95,7 @@ def analyze_instances_klu( return dummy_instances -def analyze_circuit_klu( +def analyze_circuit_cuda( analyzed_instances: Dict[str, SCoo], connections: Dict[str, str], ports: Dict[str, str], @@ -94,9 +106,9 @@ def analyze_circuit_klu( idx, Si, Sj, instance_ports = 0, [], [], {} for name, instance in analyzed_instances.items(): - si, sj, _, ports_map = scoo(instance) - Si.append(si + idx) - Sj.append(sj + idx) + si, sj, _, ports_map = scoo_cupy(instance) + Si.append(cp.asarray(si + idx)) + Sj.append(cp.asarray(sj + idx)) instance_ports.update({f"{name},{p}": i + idx for p, i in ports_map.items()}) idx += len(ports_map) @@ -115,9 +127,10 @@ def analyze_circuit_klu( Cextmap = { int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items() } - Cexti = cp.stack(list(Cextmap.keys()), 0) - Cextj = cp.stack(list(Cextmap.values()), 0) - Cext = cp.zeros((n_col, n_rhs), dtype=complex).at[Cexti, Cextj].set(1.0) + Cexti = cp.asarray(list(Cextmap.keys())) + Cextj = cp.asarray(list(Cextmap.values())) + Cext = cp.zeros((n_col, n_rhs), dtype=complex) + Cext[Cexti, Cextj] = 1.0 mask = Cj[None, :] == Si[:, None] CSi = cp.broadcast_to(Ci[None, :], mask.shape)[mask] @@ -143,7 +156,7 @@ def analyze_circuit_klu( ) -def evaluate_circuit_klu(analyzed: Any, instances: Dict[str, SType]) -> SDense: +def evaluate_circuit_cuda(analyzed: Any, instances: Dict[str, SType]) -> SDense: ( n_col, mask, @@ -162,7 +175,7 @@ def evaluate_circuit_klu(analyzed: Any, instances: Dict[str, SType]) -> SDense: Sx = [] batch_shape = () for name, pm_ in dummy_pms: - _, _, sx, ports_map = scoo(instances[name]) + _, _, sx, ports_map = scoo_cupy(instances[name]) Sx.append(sx) if len(sx.shape[:-1]) > len(batch_shape): batch_shape = sx.shape[:-1] @@ -177,42 +190,13 @@ def evaluate_circuit_klu(analyzed: Any, instances: Dict[str, SType]) -> SDense: Sx = Sx.reshape(-1, Sx.shape[-1]) # n_lhs x N I_CSx = I_CSx.reshape(-1, I_CSx.shape[-1]) # n_lhs x M - inv_I_CS_Cext = solve_klu(I_CSi, I_CSj, I_CSx, Cext) + inv_I_CS_Cext = solve_cuda(I_CSi, I_CSj, I_CSx, Cext) S_inv_I_CS_Cext = coo_mul_vec(Si, Sj, Sx, inv_I_CS_Cext) CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj] - _, n, _ = CextT_S_inv_I_CS_Cext.shape + # TODO: Check that n should be from shape[-2]. We dropped a dimension somewhere. + n = CextT_S_inv_I_CS_Cext.shape[-2] S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n) - return S, {p: i for i, p in enumerate(port_map)} - - -def _get_instance_ports(connections: Dict[str, str], ports: Dict[str, str]): - instance_ports = {} - for connection in connections.items(): - for ip in connection: - i, p = ip.split(",") - if i not in instance_ports: - instance_ports[i] = set() - instance_ports[i].add(p) - for ip in ports.values(): - i, p = ip.split(",") - if i not in instance_ports: - instance_ports[i] = set() - instance_ports[i].add(p) - return {k: natsorted(v) for k, v in instance_ports.items()} - - -def _get_dummy_instances(connections, ports): - """no longer used. deprecated by analyze_instances_klu.""" - instance_ports = _get_instance_ports(connections, ports) - dummy_instances = {} - for name, ports in instance_ports.items(): - num_ports = len(ports) - pm = {p: i for i, p in enumerate(ports)} - ij = cp.mgrid[:num_ports, :num_ports] - i = ij[0].ravel() - j = ij[1].ravel() - dummy_instances[name] = (i, j, None, pm) - return dummy_instances + return jnp.asarray(S), {p: i for i, p in enumerate(port_map)} From b1955a87db0fe7247be3ab26d0d36bc9913fe4f1 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 22 May 2024 13:48:58 -0700 Subject: [PATCH 4/9] Use jax[cuda12]. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 17f9c7af..1f0cc1a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,7 @@ dev = [ ] cuda = [ - "jax[cuda]", + "jax[cuda12]", "cupy", ] From 6e6c69357e79de9ab7ddff147f3c0c2278214bad Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 22 May 2024 13:49:26 -0700 Subject: [PATCH 5/9] Add backends test script. --- tests/test_backends.py | 118 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/test_backends.py diff --git a/tests/test_backends.py b/tests/test_backends.py new file mode 100644 index 00000000..5254cea4 --- /dev/null +++ b/tests/test_backends.py @@ -0,0 +1,118 @@ +import pytest + +import sax +import jax.numpy as jnp + + +@pytest.mark.parametrize("backend", ["cuda", "default", "klu", "fg"]) +def test_backend(backend): + instances = { + "lft": {"component": "coupler"}, + "top": {"component": "wg"}, + "rgt": {"component": "mmi"}, + } + connections = {"lft,out0": "rgt,in0", "lft,out1": "top,in0", "top,out0": "rgt,in1"} + ports = {"in0": "lft,in0", "out0": "rgt,out0"} + models = { + "wg": lambda: { + ("in0", "out0"): -0.99477 - 0.10211j, + ("out0", "in0"): -0.99477 - 0.10211j, + }, + "mmi": lambda: { + ("in0", "out0"): 0.7071067811865476, + ("in0", "out1"): 0.7071067811865476j, + ("in1", "out0"): 0.7071067811865476j, + ("in1", "out1"): 0.7071067811865476, + ("out0", "in0"): 0.7071067811865476, + ("out1", "in0"): 0.7071067811865476j, + ("out0", "in1"): 0.7071067811865476j, + ("out1", "in1"): 0.7071067811865476, + }, + "coupler": lambda: ( + jnp.array( + [ + [ + 5.19688622e-06 - 1.19777138e-05j, + 6.30595625e-16 - 1.48061189e-17j, + -3.38542541e-01 - 6.15711852e-01j, + 5.80662654e-03 - 1.11068866e-02j, + -3.38542542e-01 - 6.15711852e-01j, + -5.80662660e-03 + 1.11068866e-02j, + ], + [ + 8.59445189e-16 - 8.29783014e-16j, + -2.08640825e-06 + 8.17315497e-06j, + 2.03847666e-03 - 2.10649131e-03j, + 5.30509661e-01 + 4.62504708e-01j, + -2.03847666e-03 + 2.10649129e-03j, + 5.30509662e-01 + 4.62504708e-01j, + ], + [ + -3.38542541e-01 - 6.15711852e-01j, + 2.03847660e-03 - 2.10649129e-03j, + 7.60088070e-06 + 9.07340423e-07j, + 2.79292426e-09 + 2.79093547e-07j, + 5.07842364e-06 + 2.16385350e-06j, + -6.84244232e-08 - 5.00486817e-07j, + ], + [ + 5.80662707e-03 - 1.11068869e-02j, + 5.30509661e-01 + 4.62504708e-01j, + 2.79291895e-09 + 2.79093540e-07j, + -4.55645798e-06 + 1.50570403e-06j, + 6.84244128e-08 + 5.00486817e-07j, + -3.55812153e-06 + 4.59781091e-07j, + ], + [ + -3.38542541e-01 - 6.15711852e-01j, + -2.03847672e-03 + 2.10649131e-03j, + 5.07842364e-06 + 2.16385349e-06j, + 6.84244230e-08 + 5.00486816e-07j, + 7.60088070e-06 + 9.07340425e-07j, + -2.79292467e-09 - 2.79093547e-07j, + ], + [ + -5.80662607e-03 + 1.11068863e-02j, + 5.30509662e-01 + 4.62504708e-01j, + -6.84244296e-08 - 5.00486825e-07j, + -3.55812153e-06 + 4.59781093e-07j, + -2.79293217e-09 - 2.79093547e-07j, + -4.55645798e-06 + 1.50570403e-06j, + ], + ] + ), + {"in0": 0, "out0": 2, "out1": 4}, + ), + } + + ( + analyze_instances, + analyze_circuit, + evaluate_circuit, + ) = sax.backends.circuit_backends[backend] + + analyzed_instances = analyze_instances(instances, models) + analyzed_circuit = analyze_circuit(analyzed_instances, connections, ports) + sdict_backend = sax.sdict( + evaluate_circuit( + analyzed_circuit, + {k: models[v["component"]]() for k, v in instances.items()}, + ) + ) + + analyzed_instances = sax.backends.analyze_instances_klu(instances, models) + analyzed_circuit = sax.backends.analyze_circuit_klu( + analyzed_instances, connections, ports + ) + sdict_klu = sax.sdict( + sax.backends.evaluate_circuit_klu( + analyzed_circuit, + {k: models[v["component"]]() for k, v in instances.items()}, + ) + ) + + # Compare to klu backend as source of truth + for k in sdict_klu: + val_klu = sdict_klu[k] + val_backend = sdict_backend[k] + assert abs(val_klu - val_backend) < 1e-5 From e3743b8d97f34d27dd4906217459493e6959d74f Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 22 May 2024 13:49:45 -0700 Subject: [PATCH 6/9] Apply pre-commit to test_nbs.py. --- tests/test_nbs.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_nbs.py b/tests/test_nbs.py index 84c37b0f..9d113f4b 100644 --- a/tests/test_nbs.py +++ b/tests/test_nbs.py @@ -9,24 +9,28 @@ NBS_DIR = os.path.join(TEST_DIR, "nbs") NBS_FAIL_DIR = os.path.join(NBS_DIR, "failed") + def get_kernel(): kernel = os.environ.get("CONDA_DEFAULT_ENV", "base") if kernel == "base": kernel = "python3" return kernel + shutil.rmtree(NBS_FAIL_DIR, ignore_errors=True) os.mkdir(NBS_FAIL_DIR) + def _find_notebooks(*dir): dir = os.path.abspath(os.path.join(TEST_DIR, *dir)) for root, _, files in os.walk(dir): for file in files: - if ('checkpoint' in file) or (not file.endswith('.ipynb')): + if ("checkpoint" in file) or (not file.endswith(".ipynb")): continue yield os.path.join(root, file) -@pytest.mark.parametrize('path', sorted(_find_notebooks('nbs'))) + +@pytest.mark.parametrize("path", sorted(_find_notebooks("nbs"))) def test_nbs(path): fn = os.path.basename(path) nb = load_notebook_node(path) @@ -38,4 +42,3 @@ def test_nbs(path): output_path=None, ) raise_for_execution_errors(nb, os.path.join(NBS_FAIL_DIR, fn)) - From 5078cd69298f871f4e4d84545d381aa585ff644e Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 23 May 2024 14:23:32 -0700 Subject: [PATCH 7/9] Use loops to handle missing dimension in CUDA backend. --- sax/backends/__init__.py | 7 ++----- sax/backends/cuda.py | 36 +++++++++++++++++------------------- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/sax/backends/__init__.py b/sax/backends/__init__.py index aae94108..d319d9d7 100644 --- a/sax/backends/__init__.py +++ b/sax/backends/__init__.py @@ -78,11 +78,8 @@ evaluate_circuit_cuda, ) except ImportError: - circuit_backends["default"] = ( - analyze_instances_fg, - analyze_circuit_fg, - evaluate_circuit_fg, - ) + default_backend = "klu" if "klu" in circuit_backends else "fg" + circuit_backends["default"] = circuit_backends[default_backend] warnings.warn( "cupy not found. Please install cupy for " "better performance during circuit evaluation!" diff --git a/sax/backends/cuda.py b/sax/backends/cuda.py index d55abe55..448a84e8 100644 --- a/sax/backends/cuda.py +++ b/sax/backends/cuda.py @@ -31,19 +31,18 @@ def solve_cuda(Ai, Aj, Ax, B): Returns: array: Solution matrix. """ - # TODO: Maybe the shape of Ax is wrong? Unsure -- the KLU backend uses jax.vmap. - Ax = Ax[0] + results = [] + for Ax_mat in Ax: + # Create sparse matrix in COO format + A_coo = cupyx.scipy.sparse.coo_matrix((Ax_mat, (Ai, Aj))) - # Create sparse matrix in COO format - A_coo = cupyx.scipy.sparse.coo_matrix((Ax, (Ai, Aj))) + # Convert to CSR format for solving + A_csr = A_coo.tocsr() - # Convert to CSR format for solving - A_csr = A_coo.tocsr() + # Solve the linear system + results.append(cupyx.scipy.sparse.linalg.spsolve(A_csr, B)) - # Solve the linear system - solution = cupyx.scipy.sparse.linalg.spsolve(A_csr, B) - - return solution + return cp.asarray(results) def coo_mul_vec(Si, Sj, Sx, x): @@ -59,15 +58,15 @@ def coo_mul_vec(Si, Sj, Sx, x): Returns: array: Resulting vector from the multiplication. """ - # TODO: Maybe the shape of Ax is wrong? Unsure -- the KLU backend uses jax.vmap. - Sx = Sx[0] + results = [] + for Sx_mat, x_vec in zip(Sx, x): + # Create sparse matrix in COO format + S_coo = cupyx.scipy.sparse.coo_matrix((Sx_mat, (Si, Sj))) - # Create sparse matrix in COO format - S_coo = cupyx.scipy.sparse.coo_matrix((Sx, (Si, Sj))) + # Perform the matrix-vector multiplication + results.append(S_coo.dot(x_vec)) - # Perform the matrix-vector multiplication - result = S_coo.dot(x) - return result + return cp.asarray(results) def analyze_instances_cuda( @@ -195,8 +194,7 @@ def evaluate_circuit_cuda(analyzed: Any, instances: Dict[str, SType]) -> SDense: CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj] - # TODO: Check that n should be from shape[-2]. We dropped a dimension somewhere. - n = CextT_S_inv_I_CS_Cext.shape[-2] + _, n, _ = CextT_S_inv_I_CS_Cext.shape S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n) return jnp.asarray(S), {p: i for i, p in enumerate(port_map)} From d7ba9c26863e6a03b9b01331a0399d151bfd89e8 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 23 May 2024 14:35:36 -0700 Subject: [PATCH 8/9] Add quick start test. --- tests/test_quick_start.py | 95 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/test_quick_start.py diff --git a/tests/test_quick_start.py b/tests/test_quick_start.py new file mode 100644 index 00000000..01539cae --- /dev/null +++ b/tests/test_quick_start.py @@ -0,0 +1,95 @@ +import jax +import jax.example_libraries.optimizers as opt +import jax.numpy as jnp +import sax + + +def test_quick_start(): + """Runs the core parts of the quick start notebook. + + This does not use jax.jit, to support the cupy backend. + """ + coupling = 0.5 + kappa = coupling**0.5 + tau = (1 - coupling) ** 0.5 + coupler_dict = { + ("in0", "out0"): tau, + ("out0", "in0"): tau, + ("in0", "out1"): 1j * kappa, + ("out1", "in0"): 1j * kappa, + ("in1", "out0"): 1j * kappa, + ("out0", "in1"): 1j * kappa, + ("in1", "out1"): tau, + ("out1", "in1"): tau, + } + + coupler_dict = sax.reciprocal( + { + ("in0", "out0"): tau, + ("in0", "out1"): 1j * kappa, + ("in1", "out0"): 1j * kappa, + ("in1", "out1"): tau, + } + ) + + def coupler(coupling=0.5) -> sax.SDict: + kappa = coupling**0.5 + tau = (1 - coupling) ** 0.5 + coupler_dict = sax.reciprocal( + { + ("in0", "out0"): tau, + ("in0", "out1"): 1j * kappa, + ("in1", "out0"): 1j * kappa, + ("in1", "out1"): tau, + } + ) + return coupler_dict + + coupler(coupling=0.3) + + def waveguide( + wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0 + ) -> sax.SDict: + dwl = wl - wl0 + dneff_dwl = (ng - neff) / wl0 + neff = neff - dwl * dneff_dwl + phase = 2 * jnp.pi * neff * length / wl + transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase) + sdict = sax.reciprocal( + { + ("in0", "out0"): transmission, + } + ) + return sdict + + mzi, info = sax.circuit( + netlist={ + "instances": { + "lft": "coupler", + "top": "waveguide", + "btm": "waveguide", + "rgt": "coupler", + }, + "connections": { + "lft,out0": "btm,in0", + "btm,out0": "rgt,in0", + "lft,out1": "top,in0", + "top,out0": "rgt,in1", + }, + "ports": { + "in0": "lft,in0", + "in1": "lft,in1", + "out0": "rgt,out0", + "out1": "rgt,out1", + }, + }, + models={ + "coupler": coupler, + "waveguide": waveguide, + }, + ) + + mzi() + mzi(top={"length": 25.0}, btm={"length": 15.0}) + wl = jnp.linspace(1.51, 1.59, 1000) + mzi(wl=wl, top={"length": 25.0}, btm={"length": 15.0}) From b45f96e6895b4dea6d6a38f10f95e254af86152c Mon Sep 17 00:00:00 2001 From: Joaquin Matres <4514346+joamatab@users.noreply.github.com> Date: Sat, 16 May 2026 15:07:59 -0700 Subject: [PATCH 9/9] Remove leftover sax/backends/cuda.py (moved to src/sax/backends/). --- sax/backends/cuda.py | 200 ------------------------------------------- 1 file changed, 200 deletions(-) delete mode 100644 sax/backends/cuda.py diff --git a/sax/backends/cuda.py b/sax/backends/cuda.py deleted file mode 100644 index 448a84e8..00000000 --- a/sax/backends/cuda.py +++ /dev/null @@ -1,200 +0,0 @@ -""" SAX CUDA Backend """ - -from __future__ import annotations - -from typing import Any, Dict - -import cupy as cp -import cupyx -import cupyx.scipy.sparse.linalg -import jax.numpy as jnp -from natsort import natsorted - -from ..netlist import Component -from ..saxtypes import Model, SCoo, SDense, SType, scoo - - -def scoo_cupy(S): - Si, Sj, Sx, ports_map = scoo(S) - return cp.asarray(Si), cp.asarray(Sj), cp.asarray(Sx), ports_map - - -def solve_cuda(Ai, Aj, Ax, B): - """ - Custom solver using CuPy for sparse matrix solve. - - Args: - Ai (array): Row indices of non-zero values in the sparse matrix. - Aj (array): Column indices of non-zero values in the sparse matrix. - Ax (array): Non-zero values of the sparse matrix. - B (array): Right-hand side matrix to solve for. - - Returns: array: Solution matrix. - """ - results = [] - for Ax_mat in Ax: - # Create sparse matrix in COO format - A_coo = cupyx.scipy.sparse.coo_matrix((Ax_mat, (Ai, Aj))) - - # Convert to CSR format for solving - A_csr = A_coo.tocsr() - - # Solve the linear system - results.append(cupyx.scipy.sparse.linalg.spsolve(A_csr, B)) - - return cp.asarray(results) - - -def coo_mul_vec(Si, Sj, Sx, x): - """ - COO matrix-vector multiplication using CuPy. - - Args: - Si (array): Row indices of non-zero values in the sparse matrix. - Sj (array): Column indices of non-zero values in the sparse matrix. - Sx (array): Non-zero values of the sparse matrix. - x (array): Dense vector to multiply with the sparse matrix. - - Returns: - array: Resulting vector from the multiplication. - """ - results = [] - for Sx_mat, x_vec in zip(Sx, x): - # Create sparse matrix in COO format - S_coo = cupyx.scipy.sparse.coo_matrix((Sx_mat, (Si, Sj))) - - # Perform the matrix-vector multiplication - results.append(S_coo.dot(x_vec)) - - return cp.asarray(results) - - -def analyze_instances_cuda( - instances: Dict[str, Component], - models: Dict[str, Model], -) -> Dict[str, SCoo]: - instances, instances_old = {}, instances - for k, v in instances_old.items(): - if not isinstance(v, Component): - v = Component(**v) - instances[k] = v - model_names = set() - for i in instances.values(): - if i.info and "model" in i.info and isinstance(i.info["model"], str): - model_names.add(str(i.info["model"])) - else: - model_names.add(str(i.component)) - dummy_models = {k: scoo_cupy(models[k]()) for k in model_names} - dummy_instances = {} - for k, i in instances.items(): - if i.info and "model" in i.info and isinstance(i.info["model"], str): - dummy_instances[k] = dummy_models[str(i.info["model"])] - else: - dummy_instances[k] = dummy_models[str(i.component)] - return dummy_instances - - -def analyze_circuit_cuda( - analyzed_instances: Dict[str, SCoo], - connections: Dict[str, str], - ports: Dict[str, str], -) -> Any: - connections = {**connections, **{v: k for k, v in connections.items()}} - inverse_ports = {v: k for k, v in ports.items()} - port_map = {k: i for i, k in enumerate(ports)} - - idx, Si, Sj, instance_ports = 0, [], [], {} - for name, instance in analyzed_instances.items(): - si, sj, _, ports_map = scoo_cupy(instance) - Si.append(cp.asarray(si + idx)) - Sj.append(cp.asarray(sj + idx)) - instance_ports.update({f"{name},{p}": i + idx for p, i in ports_map.items()}) - idx += len(ports_map) - - n_col = idx - n_rhs = len(port_map) - - Si = cp.concatenate(Si, -1) - Sj = cp.concatenate(Sj, -1) - - Cmap = { - int(instance_ports[k]): int(instance_ports[v]) for k, v in connections.items() - } - Ci = cp.array(list(Cmap.keys()), dtype=cp.int32) - Cj = cp.array(list(Cmap.values()), dtype=cp.int32) - - Cextmap = { - int(instance_ports[k]): int(port_map[v]) for k, v in inverse_ports.items() - } - Cexti = cp.asarray(list(Cextmap.keys())) - Cextj = cp.asarray(list(Cextmap.values())) - Cext = cp.zeros((n_col, n_rhs), dtype=complex) - Cext[Cexti, Cextj] = 1.0 - - mask = Cj[None, :] == Si[:, None] - CSi = cp.broadcast_to(Ci[None, :], mask.shape)[mask] - - mask = (Cj[:, None] == Si[None, :]).any(0) - CSj = Sj[mask] - - Ii = Ij = cp.arange(n_col) - I_CSi = cp.concatenate([CSi, Ii], -1) - I_CSj = cp.concatenate([CSj, Ij], -1) - return ( - n_col, - mask, - Si, - Sj, - Cext, - Cexti, - Cextj, - I_CSi, - I_CSj, - tuple((k, v[1]) for k, v in analyzed_instances.items()), - tuple(port_map), - ) - - -def evaluate_circuit_cuda(analyzed: Any, instances: Dict[str, SType]) -> SDense: - ( - n_col, - mask, - Si, - Sj, - Cext, - Cexti, - Cextj, - I_CSi, - I_CSj, - dummy_pms, - port_map, - ) = analyzed - - idx = 0 - Sx = [] - batch_shape = () - for name, pm_ in dummy_pms: - _, _, sx, ports_map = scoo_cupy(instances[name]) - Sx.append(sx) - if len(sx.shape[:-1]) > len(batch_shape): - batch_shape = sx.shape[:-1] - idx += len(ports_map) - - Sx = cp.concatenate( - [cp.broadcast_to(sx, (*batch_shape, sx.shape[-1])) for sx in Sx], -1 - ) - CSx = Sx[..., mask] - Ix = cp.ones((*batch_shape, n_col)) - I_CSx = cp.concatenate([-CSx, Ix], -1) - - Sx = Sx.reshape(-1, Sx.shape[-1]) # n_lhs x N - I_CSx = I_CSx.reshape(-1, I_CSx.shape[-1]) # n_lhs x M - inv_I_CS_Cext = solve_cuda(I_CSi, I_CSj, I_CSx, Cext) - S_inv_I_CS_Cext = coo_mul_vec(Si, Sj, Sx, inv_I_CS_Cext) - - CextT_S_inv_I_CS_Cext = S_inv_I_CS_Cext[..., Cexti, :][..., :, Cextj] - - _, n, _ = CextT_S_inv_I_CS_Cext.shape - S = CextT_S_inv_I_CS_Cext.reshape(*batch_shape, n, n) - - return jnp.asarray(S), {p: i for i, p in enumerate(port_map)}