Skip to content
28 changes: 28 additions & 0 deletions glue/stimflow/src/stimflow/_chunk/_chunk_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,3 +311,31 @@ def test_partial_observable_include_memory_experiment():
assert np.count_nonzero(dets) == 0
assert 200 <= np.count_nonzero(obs[:, 0]) <= 300 # Some measurements included in partial X obs
assert np.count_nonzero(obs[:, 1]) == 0 # No measurements included in partial Z obs


def test_auto_obs():
builder = stimflow.ChunkBuilder(allowed_qubits=[0, 1])
builder.append("R", [0])
builder.add_flow(
start="auto",
end=stimflow.PauliMap({0: 'Z', 1: 'Z'}, obs_name='test'),
)
builder.add_flow(
start=stimflow.PauliMap({1: 'Z'}, obs_name='test2'),
end="auto",
)
chunk = builder.finish_chunk()
chunk.verify()

assert chunk.flows == (
stimflow.Flow(
start=stimflow.PauliMap({1: 'Z'}, obs_name='test'),
end=stimflow.PauliMap({0: 'Z', 1: 'Z'}, obs_name='test'),
center=0.5,
),
stimflow.Flow(
start=stimflow.PauliMap({1: 'Z'}, obs_name='test2'),
end=stimflow.PauliMap({1: 'Z'}, obs_name='test2'),
center=1,
),
)
6 changes: 3 additions & 3 deletions glue/stimflow/src/stimflow/_chunk/_flow_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _solve_auto_flow_starts(
except ValueError:
failure_out.append(flow)
continue
start = PauliMap({i2q[q]: "_XYZ"[stim_start[q]] for q in stim_start.pauli_indices()})
start = PauliMap({i2q[q]: "_XYZ"[stim_start[q]] for q in stim_start.pauli_indices()}, obs_name=flow.end.obs_name)
new_flows.append(flow.with_edits(start=start))

return new_flows
Expand All @@ -43,7 +43,7 @@ def _solve_auto_flow_ends(
failure_out: list[Flow],
) -> list[Flow]:

num_qubits = circuit.num_qubits
num_qubits = max(q2i.values(), default=-1) + 1
i2q = {i: q for q, i in q2i.items()}

new_flows = []
Expand All @@ -54,7 +54,7 @@ def _solve_auto_flow_ends(
except ValueError:
failure_out.append(flow)
continue
end = PauliMap({i2q[q]: "_XYZ"[stim_end[q]] for q in stim_end.pauli_indices()})
end = PauliMap({i2q[q]: "_XYZ"[stim_end[q]] for q in stim_end.pauli_indices()}, obs_name=flow.start.obs_name)
new_flows.append(flow.with_edits(end=end))

return new_flows
Expand Down
2 changes: 1 addition & 1 deletion glue/stimflow/src/stimflow/_chunk/_stabilizer_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def x_basis_subset(self) -> StabilizerCode:

def z_basis_subset(self) -> StabilizerCode:
return StabilizerCode(
stabilizers=self.stabilizers.with_only_x_tiles(),
stabilizers=self.stabilizers.with_only_z_tiles(),
logicals=self.list_pure_basis_observables("Z"),
)

Expand Down
23 changes: 15 additions & 8 deletions glue/stimflow/src/stimflow/_core/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,21 @@ def __str__(self) -> str:
return result

def __repr__(self):
return (
f"stimflow.Flow(start={self.start!r}, "
f"end={self.end!r}, "
f"measurement_indices={self.measurement_indices!r}, "
f"flags={self.flags!r}, "
f"center={self.center!r}, "
f"sign={self.sign!r}"
)
lines = ["stimflow.Flow("]
if self.start:
lines.append(f"start={self.start!r},")
if self.end:
lines.append(f"end={self.end!r},")
if self.measurement_indices:
lines.append(f"measurement_indices={self.measurement_indices!r},")
if self.flags:
lines.append(f"flags={self.flags!r},")
if self.center is not None:
lines.append(f"center={self.center!r},")
if self.sign is not None:
lines.append(f"sign={self.sign!r},")
lines.append(")")
return '\n'.join(lines)

def with_xz_flipped(self) -> Flow:
return self.with_edits(start=self.start.with_xz_flipped(), end=self.end.with_xz_flipped())
Expand Down
2 changes: 0 additions & 2 deletions src/stim/simulators/frame_simulator_pybind_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,8 +587,6 @@ def test_generate_bernoulli_samples():
assert np.all(v == 0)

sim.generate_bernoulli_samples(256 - 101, p=1, bit_packed=True, out=v[1:-11])
for k in v:
print(k)
assert np.all(v[1:-12] == 0xFF)
assert v[-12] == 7
assert np.all(v[-11:] == 0)
Expand Down
Loading