Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
90ec2cf
🐛 Fix RL Training and Improve Structure (#573)
flowerthrower Mar 11, 2026
4a0ed8d
Merge commit '90ec2cf' into fix-RL-training-bug
flowerthrower May 11, 2026
68eb338
🎨 improve seed and training defaults
flowerthrower May 11, 2026
f9de637
🎨 adjust test step limits
flowerthrower May 11, 2026
55e5e08
⏪ revert unrelated changes
flowerthrower May 11, 2026
ba9042d
🎨 pre-commit fixes
pre-commit-ci[bot] May 11, 2026
dca4827
🎨 improve comments
flowerthrower May 11, 2026
1e523e1
✅ fix synthesis size limit for bqskit passes
flowerthrower May 11, 2026
6d6487a
🎨 pre-commit fixes
pre-commit-ci[bot] May 11, 2026
7a300a2
🎨 reduce test training overhead
flowerthrower May 12, 2026
d64a97f
🎨 add comments
flowerthrower May 12, 2026
9f2697e
🎨 reduce number of training steps
flowerthrower May 12, 2026
010fa68
🎨 add changelog entry
flowerthrower May 12, 2026
788ec25
Merge remote-tracking branch 'origin/main' into fix-RL-training-bug
flowerthrower May 29, 2026
a474a8f
🎨 improve error reporting
flowerthrower May 29, 2026
51b20af
✅ improve coverage
flowerthrower May 29, 2026
7e9a369
✅ fix test for qiskit<2
flowerthrower Jun 1, 2026
0694749
🎨 pre-commit fixes
pre-commit-ci[bot] Jun 1, 2026
ce7491a
Merge branch 'main' into fix-RL-training-bug
flowerthrower Jun 1, 2026
71d9be3
🎨 enable random training
flowerthrower Jun 3, 2026
9353191
🎨 improve comments
flowerthrower Jun 3, 2026
3c1b076
🎨 make random seed default for training
flowerthrower Jun 3, 2026
95af4a3
🎨 remove redundant imports
flowerthrower Jun 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel

### Changed

- 🎨 Improve the RL state machine logic ([#677]) ([**@flowerthrower**])
- 🐛 Support BQSKit conversion of IQM's native `r` gate ([#679]) ([**@flowerthrower**])
- 🔧 Replace `mypy` with `ty` ([#572]) ([**@denialhaag**])
- 🐛 Fix instruction duration unit in estimated success probability calculation ([#445]) ([**@Shaobo-Zhou**])
Expand Down Expand Up @@ -47,6 +48,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool

<!-- PR links -->

[#677]: https://github.com/munich-quantum-toolkit/predictor/pull/677
[#679]: https://github.com/munich-quantum-toolkit/predictor/pull/679
[#572]: https://github.com/munich-quantum-toolkit/predictor/pull/572
[#489]: https://github.com/munich-quantum-toolkit/predictor/pull/489
Expand Down
51 changes: 16 additions & 35 deletions src/mqt/predictor/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

if TYPE_CHECKING:
from qiskit import QuantumCircuit
from qiskit.circuit import QuantumRegister, Qubit
from qiskit.transpiler import Target
from sklearn.ensemble import RandomForestRegressor

Expand Down Expand Up @@ -62,44 +61,22 @@ def expected_fidelity(qc: QuantumCircuit, device: Target, precision: int = 10) -

if gate_type != "barrier":
assert len(qargs) in [1, 2]
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = qc.find_bit(qargs[0]).index

if len(qargs) == 1:
specific_fidelity = 1 - device[gate_type][first_qubit_idx,].error
else:
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error

second_qubit_idx = qc.find_bit(qargs[1]).index
try:
specific_fidelity = 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
except KeyError:
msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties."
raise KeyError(msg) from None
res *= specific_fidelity

return float(np.round(res, precision).item())


def calc_qubit_index(qargs: list[Qubit], qregs: list[QuantumRegister], index: int) -> int:
    """Translate a register-local qubit into its position across all registers.

    The registers are walked in the order given by ``qregs``. The sizes of all
    registers that do not contain the qubit are accumulated, and the qubit's
    position inside the first register that does contain it is added on top.

    Arguments:
        qargs: The qubits an instruction acts on.
        qregs: The quantum registers of the circuit, in order.
        index: The position within ``qargs`` of the qubit to look up.

    Returns:
        The accumulated size of the preceding registers plus the qubit's
        position within the register that contains it.

    Raises:
        ValueError: If none of the given registers contains the qubit.
    """
    preceding_qubits = 0
    for register in qregs:
        # Looked up inside the loop so that nothing is indexed when ``qregs`` is empty.
        qubit = qargs[index]
        if qubit in register:
            global_index: int = preceding_qubits + register.index(qubit)
            return global_index
        # Not in this register: every qubit of it precedes the one we look for.
        preceding_qubits += register.size

    error_msg = f"Global qubit index for local qubit {index} index not found."
    raise ValueError(error_msg)


def estimated_success_probability(qc: QuantumCircuit, device: Target, precision: int = 10) -> float:
"""Calculates the estimated success probability of a given quantum circuit on a given device.

Expand All @@ -125,7 +102,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
if gate_type == "barrier" or gate_type == "id":
continue
assert len(qargs) in (1, 2)
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = qc.find_bit(qargs[0]).index
active_qubits.add(first_qubit_idx)

if len(qargs) == 1: # single-qubit gate
Expand All @@ -140,7 +117,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
))
exec_time_per_qubit[first_qubit_idx] += duration
else: # multi-qubit gate
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
second_qubit_idx = qc.find_bit(qargs[1]).index
active_qubits.add(second_qubit_idx)
duration = device[gate_type][first_qubit_idx, second_qubit_idx].duration
op_times.append((gate_type, [first_qubit_idx, second_qubit_idx], duration, "s"))
Expand Down Expand Up @@ -191,7 +168,7 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
continue

assert len(qargs) in (1, 2)
first_qubit_idx = calc_qubit_index(qargs, qc.qregs, 0)
first_qubit_idx = scheduled_circ.find_bit(qargs[0]).index

if len(qargs) == 1:
if gate_type == "measure":
Expand All @@ -213,8 +190,12 @@ def estimated_success_probability(qc: QuantumCircuit, device: Target, precision:
continue
res *= 1 - device[gate_type][first_qubit_idx,].error
else:
second_qubit_idx = calc_qubit_index(qargs, qc.qregs, 1)
res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
second_qubit_idx = scheduled_circ.find_bit(qargs[1]).index
try:
res *= 1 - device[gate_type][first_qubit_idx, second_qubit_idx].error
except KeyError:
msg = f"Error rate for gate {gate_type} on qubits {first_qubit_idx} and {second_qubit_idx} not found in device properties."
raise KeyError(msg) from None

if qiskit_version >= "2.0.0":
for i in range(device.num_qubits):
Expand Down
10 changes: 5 additions & 5 deletions src/mqt/predictor/rl/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from __future__ import annotations

import os
import sys
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -90,6 +89,7 @@
from qiskit.passmanager.base_tasks import Task

TaskList = list[Task | TketBasePass | PreProcessTKETRoutingAfterQiskitLayout]
from qiskit.passmanager import PropertySet


class CompilationOrigin(str, Enum):
Expand Down Expand Up @@ -146,7 +146,7 @@ class DeviceDependentAction(Action):
Callable[..., tuple[Any, ...] | Circuit],
]
)
do_while: Callable[[dict[str, Circuit]], bool] | None = None
do_while: Callable[[PropertySet], bool] | None = None


# Registry of actions
Expand Down Expand Up @@ -332,7 +332,7 @@ def remove_action(name: str) -> None:
circuit,
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" else 3,
max_synthesis_size=3,
seed=10,
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
),
Expand Down Expand Up @@ -431,7 +431,7 @@ def remove_action(name: str) -> None:
with_mapping=True,
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3,
max_synthesis_size=3,
seed=10,
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
)
Expand Down Expand Up @@ -461,7 +461,7 @@ def remove_action(name: str) -> None:
model=MachineModel(bqskit_circuit.num_qudits, gate_set=get_bqskit_native_gates(device)),
optimization_level=1 if os.getenv("GITHUB_ACTIONS") == "true" else 2,
synthesis_epsilon=1e-1 if os.getenv("GITHUB_ACTIONS") == "true" else 1e-8,
max_synthesis_size=2 if os.getenv("GITHUB_ACTIONS") == "true" and sys.platform != "linux" else 3,
max_synthesis_size=3,
seed=10,
num_workers=1 if os.getenv("GITHUB_ACTIONS") == "true" else -1,
)
Expand Down
12 changes: 9 additions & 3 deletions src/mqt/predictor/rl/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,24 @@ def train_model(
timesteps: int = 1000,
verbose: int = 2,
test: bool = False,
seed: int | None = None,
) -> None:
"""Trains all models for the given reward functions and device.

Arguments:
timesteps: The number of timesteps to train the model. Defaults to 1000.
verbose: The verbosity level. Defaults to 2.
test: Whether to train the model for testing purposes. Defaults to False.
seed: The random seed to use for reproducible training. Set to None to use true randomness.
Defaults to None.
"""
if seed is not None:
set_random_seed(seed)
if test:
set_random_seed(0) # for reproducibility
n_steps = 10
# minimum training overhead
n_steps = max(timesteps, 2)
n_epochs = 1
batch_size = 10
batch_size = n_steps
progress_bar = False
else:
# default PPO values
Expand All @@ -120,6 +125,7 @@ def train_model(
n_steps=n_steps,
batch_size=batch_size,
n_epochs=n_epochs,
seed=seed,
)
# Training Loop: In each iteration, the agent collects n_steps steps (rollout),
# updates the policy for n_epochs, and then repeats the process until total_timesteps steps have been taken.
Expand Down
Loading
Loading