diff --git a/src/tsim/circuit.py b/src/tsim/circuit.py index ce49e2af..bab50331 100644 --- a/src/tsim/circuit.py +++ b/src/tsim/circuit.py @@ -797,6 +797,11 @@ def compile_detector_sampler( ) -> CompiledDetectorSampler: """Compile circuit into a detector sampler. + Connected components whose single output is deterministically given by + one f-variable are handled via a fast direct path (no compilation or + autoregressive sampling). Remaining components go through the full + compilation pipeline. + Args: strategy: Stabilizer rank decomposition strategy. Must be one of "cat5", "bss", "cutting". diff --git a/src/tsim/compile/pipeline.py b/src/tsim/compile/pipeline.py index d462bf41..dfe25846 100644 --- a/src/tsim/compile/pipeline.py +++ b/src/tsim/compile/pipeline.py @@ -7,14 +7,19 @@ import jax.numpy as jnp import pyzx_param as zx from pyzx_param.graph.base import BaseGraph -from pyzx_param.simulate import DecompositionStrategy from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs from tsim.compile.stabrank import find_stab -from tsim.core.graph import ConnectedComponent, connected_components, get_params +from tsim.core.graph import ( + ConnectedComponent, + classify_direct, + connected_components, + get_params, +) from tsim.core.types import CompiledComponent, CompiledProgram, SamplingGraph DecompositionMode = Literal["sequential", "joint"] +from pyzx_param.simulate import DecompositionStrategy def compile_program( @@ -52,24 +57,57 @@ def compile_program( f_indices_global = _get_f_indices(prepared.graph) num_outputs = prepared.num_outputs + direct_f_indices: list[int] = [] + direct_flips: list[bool] = [] + direct_output_order: list[int] = [] compiled_components: list[CompiledComponent] = [] - output_order: list[int] = [] + compiled_output_order: list[int] = [] sorted_components = sorted(components, key=lambda c: len(c.output_indices)) for component in sorted_components: - compiled = _compile_component( - component=component, - 
f_indices_global=f_indices_global, - mode=mode, - strategy=strategy, + result = classify_direct(component) + if result is not None: + f_idx, flip = result + direct_f_indices.append(f_idx) + direct_flips.append(flip) + direct_output_order.append(component.output_indices[0]) + else: + compiled = _compile_component( + component=component, + f_indices_global=f_indices_global, + mode=mode, + strategy=strategy, + ) + compiled_components.append(compiled) + compiled_output_order.extend(component.output_indices) + + # Sort direct entries by output index so that the concatenation layout + # in sample_program matches the original output order as closely as + # possible. When transform_error_basis also prioritises outputs, this + # often yields an identity permutation and avoids reindexing at sample time. + if direct_output_order: + order = sorted( + range(len(direct_output_order)), key=direct_output_order.__getitem__ ) - compiled_components.append(compiled) - output_order.extend(component.output_indices) + direct_f_indices = [direct_f_indices[i] for i in order] + direct_flips = [direct_flips[i] for i in order] + direct_output_order = [direct_output_order[i] for i in order] + + # output_order must match the concatenation layout in sample_program: + # [direct bits, compiled_0 outputs, compiled_1 outputs, ...] 
def classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None:
    """Check if a component is directly determined by a single f-variable.

    A component qualifies when its graph consists of exactly two vertices (one
    output and one Z-spider) connected by a Hadamard edge, where the Z-spider
    carries a single ``f`` parameter and a constant phase of either 0
    (no flip) or π (flip).

    Args:
        component: A connected component to classify.

    Returns:
        ``(f_index, flip)`` if the fast path applies, otherwise ``None``.
        ``f_index`` is the integer suffix of the parameter name (``"f7"`` ->
        ``7``). A parameter that starts with ``"f"`` but has no plain decimal
        suffix does not qualify and yields ``None`` rather than raising.

    """
    graph = component.graph

    outputs = list(graph.outputs())
    if len(outputs) != 1:
        return None

    # The output vertex plus its single neighbour must be the whole graph.
    if len(list(graph.vertices())) != 2:
        return None

    v_out = outputs[0]
    neighbors = list(graph.neighbors(v_out))
    if len(neighbors) != 1:
        return None

    v_det = neighbors[0]
    if graph.type(v_det) != VertexType.Z:
        return None
    if graph.edge_type(graph.edge(v_out, v_det)) != EdgeType.HADAMARD:
        return None

    params = graph.get_params(v_det)
    if len(params) != 1:
        return None
    f_param = next(iter(params))
    if not f_param.startswith("f"):
        return None

    # Only names of the form "f<digits>" map to an f-index. Anything else
    # (a bare "f", "foo", ...) falls back to the full compilation pipeline
    # instead of raising ValueError from int().
    suffix = f_param[1:]
    if not (suffix.isascii() and suffix.isdigit()):
        return None

    # The f-parameter must be the only parameter anywhere in the component.
    if get_params(graph) != {f_param}:
        return None

    phase = graph.phase(v_det)
    if phase == 0:
        flip = False
    elif phase == Fraction(1, 1):
        flip = True
    else:
        return None

    return int(suffix), flip
+ output_detectors = [] + for v_out in g.outputs(): + neighbors = list(g.neighbors(v_out)) + if ( + len(neighbors) == 1 + and neighbors[0] in g._phaseVars + and g._phaseVars[neighbors[0]] + ): + output_detectors.append(neighbors[0]) + + output_det_set = set(output_detectors) + rest = [ + v + for v in g.vertices() + if v not in output_det_set and v in g._phaseVars and g._phaseVars[v] ] + parametrized_vertices = output_detectors + rest if not parametrized_vertices: g.scalar = Scalar() diff --git a/src/tsim/core/types.py b/src/tsim/core/types.py index c8f1cd81..3fe18691 100644 --- a/src/tsim/core/types.py +++ b/src/tsim/core/types.py @@ -86,8 +86,13 @@ class CompiledProgram: Attributes: components: The compiled components, sorted by number of outputs. - output_order: Array for reordering component outputs to final order. - final_samples = combined[:, np.argsort(output_order)] + direct_f_indices: Precomputed f-parameter indices for direct components. + direct_flips: Precomputed flip flags for direct components. + output_order: Maps concatenated position to original output index. + The first ``len(direct_f_indices)`` entries correspond to direct + components; the remainder to compiled components. + output_reindex: Precomputed ``argsort(output_order)`` permutation, + or ``None`` when the outputs are already in order. num_outputs: Total number of outputs across all components. num_f_params: Total number of f-parameters. num_detectors: Number of detector outputs (for detector sampling). @@ -95,7 +100,10 @@ class CompiledProgram: """ components: tuple[CompiledComponent, ...] + direct_f_indices: Array + direct_flips: Array output_order: Array + output_reindex: Array | None num_outputs: int num_f_params: int num_detectors: int diff --git a/src/tsim/sampler.py b/src/tsim/sampler.py index 54eb3441..9ba16ee8 100644 --- a/src/tsim/sampler.py +++ b/src/tsim/sampler.py @@ -130,8 +130,19 @@ def sample_program( match the original output indices. 
""" + batch_size = f_params.shape[0] results: list[jax.Array] = [] + if program.num_outputs == 0: + return jnp.zeros((batch_size, 0), dtype=jnp.bool_) + + if len(program.direct_f_indices) > 0: + direct_bits = ( + f_params[:, program.direct_f_indices].astype(jnp.bool_) + ^ program.direct_flips + ) + results.append(direct_bits) + for component in program.components: samples, key, max_norm_deviation = sample_component(component, f_params, key) if np.isclose(max_norm_deviation, 1): @@ -149,7 +160,9 @@ def sample_program( results.append(samples) combined = jnp.concatenate(results, axis=1) - return combined[:, jnp.argsort(program.output_order)] + if program.output_reindex is not None: + combined = combined[:, program.output_reindex] + return combined class _CompiledSamplerBase: @@ -193,6 +206,20 @@ def __init__( self.circuit = circuit self._num_detectors = prepared.num_detectors + # Pre-cache numpy arrays for the direct fast path so we don't + # convert from JAX on every sample call. + prog = self._program + n_direct = len(prog.direct_f_indices) + self._direct_f_indices = np.asarray(prog.direct_f_indices) + self._direct_flips = np.asarray(prog.direct_flips, dtype=np.bool_) + self._direct_reindex = ( + np.asarray(prog.output_reindex) if prog.output_reindex is not None else None + ) + self._direct_has_flips = bool(np.any(self._direct_flips)) + self._direct_contiguous = n_direct > 0 and np.array_equal( + self._direct_f_indices, np.arange(n_direct) + ) + def _peak_bytes_per_sample(self) -> int: """Estimate peak device memory per sample from compiled program structure.""" peak = 0 @@ -258,6 +285,9 @@ def _sample_batches( Samples array, or (samples, reference) tuple when compute_reference=True. 
def _sample_direct(self, shots: int) -> np.ndarray:
    """Fast path when all components are direct (pure numpy, no JAX).

    Args:
        shots: Number of samples to draw; forwarded to the channel sampler.

    Returns:
        Boolean array with one row per shot and one column per direct
        output, with flips applied and columns permuted by the cached
        reindex (when one is set).

    """
    f_params = self._channel_sampler.sample(shots)
    n_direct = len(self._direct_f_indices)

    if self._direct_contiguous:
        # Indices are exactly 0..n_direct-1, so a plain slice selects them.
        selected = f_params[:, :n_direct]
    else:
        selected = f_params[:, self._direct_f_indices]

    # Cast to bool *before* XOR-ing with the flip mask, as sample_program
    # does. XOR on the raw values would raise for float input and give wrong
    # bits for integer values other than 0/1 (2 ^ True == 3, still truthy).
    bits = selected.astype(np.bool_)
    if self._direct_has_flips:
        bits = bits ^ self._direct_flips
    if self._direct_reindex is not None:
        bits = bits[:, self._direct_reindex]
    return bits
def test_no_detectors_with_reference_sample():
    """A detector sampler for a circuit without detectors yields zero-width samples."""
    shots = 10
    sampler = Circuit("R 0\nH 0\nM 0").compile_detector_sampler()

    # Plain sampling first, then with the detector reference sample enabled;
    # the latter previously crashed with an empty concatenation.
    for extra_kwargs in ({}, {"use_detector_reference_sample": True}):
        samples = sampler.sample(shots, **extra_kwargs)
        assert samples.shape == (shots, 0)