Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/tsim/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,11 @@ def compile_detector_sampler(
) -> CompiledDetectorSampler:
"""Compile circuit into a detector sampler.

Connected components whose single output is deterministically given by
one f-variable are handled via a fast direct path (no compilation or
autoregressive sampling). Remaining components go through the full
compilation pipeline.

Args:
strategy: Stabilizer rank decomposition strategy.
Must be one of "cat5", "bss", "cutting".
Expand Down
60 changes: 49 additions & 11 deletions src/tsim/compile/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
import jax.numpy as jnp
import pyzx_param as zx
from pyzx_param.graph.base import BaseGraph
from pyzx_param.simulate import DecompositionStrategy

from tsim.compile.compile import CompiledScalarGraphs, compile_scalar_graphs
from tsim.compile.stabrank import find_stab
from tsim.core.graph import ConnectedComponent, connected_components, get_params
from tsim.core.graph import (
ConnectedComponent,
classify_direct,
connected_components,
get_params,
)
from tsim.core.types import CompiledComponent, CompiledProgram, SamplingGraph

DecompositionMode = Literal["sequential", "joint"]
from pyzx_param.simulate import DecompositionStrategy


def compile_program(
Expand Down Expand Up @@ -52,24 +57,57 @@ def compile_program(
f_indices_global = _get_f_indices(prepared.graph)
num_outputs = prepared.num_outputs

direct_f_indices: list[int] = []
direct_flips: list[bool] = []
direct_output_order: list[int] = []
compiled_components: list[CompiledComponent] = []
output_order: list[int] = []
compiled_output_order: list[int] = []

sorted_components = sorted(components, key=lambda c: len(c.output_indices))

for component in sorted_components:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
strategy=strategy,
result = classify_direct(component)
if result is not None:
f_idx, flip = result
direct_f_indices.append(f_idx)
direct_flips.append(flip)
direct_output_order.append(component.output_indices[0])
else:
compiled = _compile_component(
component=component,
f_indices_global=f_indices_global,
mode=mode,
strategy=strategy,
)
compiled_components.append(compiled)
compiled_output_order.extend(component.output_indices)

# Sort direct entries by output index so that the concatenation layout
# in sample_program matches the original output order as closely as
# possible. When transform_error_basis also prioritises outputs, this
# often yields an identity permutation and avoids reindexing at sample time.
if direct_output_order:
order = sorted(
range(len(direct_output_order)), key=direct_output_order.__getitem__
)
compiled_components.append(compiled)
output_order.extend(component.output_indices)
direct_f_indices = [direct_f_indices[i] for i in order]
direct_flips = [direct_flips[i] for i in order]
direct_output_order = [direct_output_order[i] for i in order]

# output_order must match the concatenation layout in sample_program:
# [direct bits, compiled_0 outputs, compiled_1 outputs, ...]
output_order = jnp.array(
direct_output_order + compiled_output_order, dtype=jnp.int32
)
reindex = jnp.argsort(output_order)
is_identity = bool(jnp.all(reindex == jnp.arange(len(output_order))))

return CompiledProgram(
components=tuple(compiled_components),
output_order=jnp.array(output_order, dtype=jnp.int32),
direct_f_indices=jnp.array(direct_f_indices, dtype=jnp.int32),
direct_flips=jnp.array(direct_flips, dtype=jnp.bool_),
output_order=output_order,
output_reindex=None if is_identity else reindex,
num_outputs=num_outputs,
num_f_params=len(f_indices_global),
num_detectors=prepared.num_detectors,
Expand Down
83 changes: 80 additions & 3 deletions src/tsim/core/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pyzx_param.graph.graph import Graph
from pyzx_param.graph.graph_s import GraphS
from pyzx_param.graph.scalar import Scalar
from pyzx_param.utils import VertexType
from pyzx_param.utils import EdgeType, VertexType

from tsim.core.instructions import GraphRepresentation
from tsim.core.parse import parse_stim_circuit
Expand Down Expand Up @@ -65,6 +65,65 @@ def connected_components(g: BaseGraph) -> list[ConnectedComponent]:
return components


def classify_direct(
    component: ConnectedComponent,
) -> tuple[int, bool] | None:
    """Check if a component is directly determined by a single f-variable.

    A component qualifies when its graph consists of exactly two vertices — one
    boundary output and one Z-spider — connected by a Hadamard edge, where the
    Z-spider carries a single ``f`` parameter and a constant phase of either 0
    (no flip) or π (flip).

    This is purely an optimisation probe: every condition that is not met
    results in ``None`` so the caller can fall back to the regular compilation
    path. It never raises for an unexpected parameter name.

    Args:
        component: A connected component to classify.

    Returns:
        ``(f_index, flip)`` if the fast path applies, otherwise ``None``.
        ``f_index`` is the decimal suffix of the parameter name (``"f3"``
        yields ``3``); ``flip`` is ``True`` when the spider's constant phase
        equals ``Fraction(1, 1)``.

    """
    graph = component.graph

    outputs = list(graph.outputs())
    if len(outputs) != 1:
        return None
    if len(list(graph.vertices())) != 2:
        return None

    v_out = outputs[0]
    neighbors = list(graph.neighbors(v_out))
    if len(neighbors) != 1:
        return None

    v_det = neighbors[0]
    if graph.type(v_det) != VertexType.Z:
        return None
    if graph.edge_type(graph.edge(v_out, v_det)) != EdgeType.HADAMARD:
        return None

    params = graph.get_params(v_det)
    if len(params) != 1:
        return None
    f_param = next(iter(params))
    if not f_param.startswith("f"):
        return None

    # Only accept a plain non-negative decimal suffix. ``int()`` alone would
    # raise on names such as "foo" and would silently accept "f-1" or "f1_0",
    # neither of which is a valid f-parameter column index.
    suffix = f_param[1:]
    if not (suffix.isascii() and suffix.isdigit()):
        return None

    # The parameter on the spider must be the only one reported for the graph.
    if get_params(graph) != {f_param}:
        return None

    phase = graph.phase(v_det)
    if phase == 0:
        flip = False
    elif phase == Fraction(1, 1):
        flip = True
    else:
        return None

    return int(suffix), flip


def _collect_vertices(
g: BaseGraph,
start: Any,
Expand Down Expand Up @@ -274,9 +333,27 @@ def transform_error_basis(
then f0 = e1 XOR e3.

"""
parametrized_vertices = [
v for v in g.vertices() if v in g._phaseVars and g._phaseVars[v]
# Prioritize output-connected detector vertices so that f0, f1, ...
# are assigned in output order. This maximises the chance that the
# direct-component fast path produces an identity permutation, avoiding
# a column reindex at sample time.
output_detectors = []
for v_out in g.outputs():
neighbors = list(g.neighbors(v_out))
if (
len(neighbors) == 1
and neighbors[0] in g._phaseVars
and g._phaseVars[neighbors[0]]
):
output_detectors.append(neighbors[0])

output_det_set = set(output_detectors)
rest = [
v
for v in g.vertices()
if v not in output_det_set and v in g._phaseVars and g._phaseVars[v]
]
parametrized_vertices = output_detectors + rest

if not parametrized_vertices:
g.scalar = Scalar()
Expand Down
12 changes: 10 additions & 2 deletions src/tsim/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,24 @@ class CompiledProgram:

Attributes:
components: The compiled components, sorted by number of outputs.
output_order: Array for reordering component outputs to final order.
final_samples = combined[:, np.argsort(output_order)]
direct_f_indices: Precomputed f-parameter indices for direct components.
direct_flips: Precomputed flip flags for direct components.
output_order: Maps concatenated position to original output index.
The first ``len(direct_f_indices)`` entries correspond to direct
components; the remainder to compiled components.
output_reindex: Precomputed ``argsort(output_order)`` permutation,
or ``None`` when the outputs are already in order.
num_outputs: Total number of outputs across all components.
num_f_params: Total number of f-parameters.
num_detectors: Number of detector outputs (for detector sampling).

"""

components: tuple[CompiledComponent, ...]
direct_f_indices: Array
direct_flips: Array
output_order: Array
output_reindex: Array | None
num_outputs: int
num_f_params: int
num_detectors: int
63 changes: 60 additions & 3 deletions src/tsim/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,19 @@ def sample_program(
match the original output indices.

"""
batch_size = f_params.shape[0]
results: list[jax.Array] = []

if program.num_outputs == 0:
return jnp.zeros((batch_size, 0), dtype=jnp.bool_)

if len(program.direct_f_indices) > 0:
direct_bits = (
f_params[:, program.direct_f_indices].astype(jnp.bool_)
^ program.direct_flips
)
results.append(direct_bits)

for component in program.components:
samples, key, max_norm_deviation = sample_component(component, f_params, key)
if np.isclose(max_norm_deviation, 1):
Expand All @@ -149,7 +160,9 @@ def sample_program(
results.append(samples)

combined = jnp.concatenate(results, axis=1)
return combined[:, jnp.argsort(program.output_order)]
if program.output_reindex is not None:
combined = combined[:, program.output_reindex]
return combined


class _CompiledSamplerBase:
Expand Down Expand Up @@ -193,6 +206,20 @@ def __init__(
self.circuit = circuit
self._num_detectors = prepared.num_detectors

# Pre-cache numpy arrays for the direct fast path so we don't
# convert from JAX on every sample call.
prog = self._program
n_direct = len(prog.direct_f_indices)
self._direct_f_indices = np.asarray(prog.direct_f_indices)
self._direct_flips = np.asarray(prog.direct_flips, dtype=np.bool_)
self._direct_reindex = (
np.asarray(prog.output_reindex) if prog.output_reindex is not None else None
)
self._direct_has_flips = bool(np.any(self._direct_flips))
self._direct_contiguous = n_direct > 0 and np.array_equal(
self._direct_f_indices, np.arange(n_direct)
)

def _peak_bytes_per_sample(self) -> int:
"""Estimate peak device memory per sample from compiled program structure."""
peak = 0
Expand Down Expand Up @@ -258,6 +285,9 @@ def _sample_batches(
Samples array, or (samples, reference) tuple when compute_reference=True.

"""
if not self._program.components and not compute_reference:
return self._sample_direct(shots)

if batch_size is None:
max_batch_size = self._estimate_batch_size()
num_batches = max(1, ceil(shots / max_batch_size))
Expand Down Expand Up @@ -295,8 +325,24 @@ def _sample_batches(
return result, reference
return result

def _sample_direct(self, shots: int) -> np.ndarray:
    """Fast path when all components are direct (pure numpy, no JAX).

    Draws ``shots`` rows from the channel sampler, picks the f-parameter
    columns that feed the direct outputs, applies the precomputed flips and
    finally permutes the columns into output order.

    The selected columns are cast to ``bool`` *before* the XOR, matching the
    order used by the JAX paths (``astype(bool) ^ flips``). This also keeps
    the XOR well defined when the sampler returns a non-integer dtype.

    Args:
        shots: Number of samples to draw.

    Returns:
        Boolean array with one row per shot and one column per direct output.

    """
    f_params = self._channel_sampler.sample(shots)
    n_direct = len(self._direct_f_indices)

    if self._direct_contiguous:
        # Indices are exactly 0..n_direct-1, so a basic slice (a view) is
        # enough and avoids a fancy-indexing gather.
        selected = f_params[:, :n_direct]
    else:
        selected = f_params[:, self._direct_f_indices]

    # ``astype`` returns a fresh array, so the in-place XOR below can never
    # write into the sampler's own buffer.
    bits = selected.astype(np.bool_)
    if self._direct_has_flips:
        bits ^= self._direct_flips
    if self._direct_reindex is not None:
        bits = bits[:, self._direct_reindex]
    return bits

def __repr__(self) -> str:
"""Return a string representation with compilation statistics."""
n_direct = len(self._program.direct_f_indices)

c_graphs = []
c_params = []
c_a_terms = []
Expand Down Expand Up @@ -335,11 +381,13 @@ def _format_bytes(n: int) -> str:
error_channel_bits = sum(
channel.num_bits for channel in self._channel_sampler.channels
)
max_outputs = int(np.max(num_outputs)) if num_outputs else 0

return (
f"{type(self).__name__}({np.sum(c_graphs)} graphs, "
f"{type(self).__name__}({n_direct} direct, "
f"{np.sum(c_graphs)} graphs, "
f"{error_channel_bits} error channel bits, "
f"{np.max(num_outputs)} outputs for largest cc, "
f"{max_outputs} outputs for largest cc, "
f"≤ {np.max(c_params) if c_params else 0} parameters, {np.sum(c_a_terms)} A terms, "
f"{np.sum(c_b_terms)} B terms, "
f"{np.sum(c_c_terms)} C terms, {np.sum(c_d_terms)} D terms, "
Expand Down Expand Up @@ -590,6 +638,15 @@ def probability_of(self, state: np.ndarray, *, batch_size: int) -> np.ndarray:
p_norm = jnp.ones(batch_size)
p_joint = jnp.ones(batch_size)

if len(self._program.direct_f_indices) > 0:
bits = (
f_samples[:, self._program.direct_f_indices].astype(jnp.bool_)
^ self._program.direct_flips
)
n_direct = len(self._program.direct_f_indices)
targets = state[self._program.output_order[:n_direct]]
p_joint = p_joint * (bits == targets).all(axis=1)

for component in self._program.components:
assert len(component.compiled_scalar_graphs) == 2

Expand Down
17 changes: 17 additions & 0 deletions test/integration/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ def test_sample_program_raises_on_component_norm_deviation(monkeypatch):
)
program = CompiledProgram(
components=components,
direct_f_indices=jnp.array([], dtype=jnp.int32),
direct_flips=jnp.array([], dtype=jnp.bool_),
output_order=jnp.array([0, 1]),
output_reindex=None,
num_outputs=2,
num_f_params=0,
num_detectors=0,
Expand Down Expand Up @@ -518,3 +521,17 @@ def test_compare_to_statevector_simulator_and_pyzx_tensor_with_arbitrary_rotatio
assert np.allclose(
tsim_state_vector, pyzx_state_vector, atol=tol, rtol=tol
), f"Seed: {seed}"


def test_no_detectors_with_reference_sample():
    """A detector sampler for a detector-free circuit yields zero-width samples."""
    sampler = Circuit("R 0\nH 0\nM 0").compile_detector_sampler()
    shots = 10

    # First the plain call, then the reference-sample variant; the latter used
    # to crash on an empty concatenation.
    for extra_kwargs in ({}, {"use_detector_reference_sample": True}):
        samples = sampler.sample(shots, **extra_kwargs)
        assert samples.shape == (shots, 0)
Loading