Skip to content

Commit 894e508

Browse files
committed
linted
Signed-off-by: kvmto <kmato@nvidia.com>
1 parent 5a5da42 commit 894e508

2 files changed

Lines changed: 7 additions & 10 deletions

File tree

code/qec/dem_sampling.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,10 @@ def _reset_sampler_cache() -> None:
8484

8585

8686
def dem_sampling(
87-
H: torch.Tensor, p: torch.Tensor, batch_size: int, device_id: int | None = None
87+
H: torch.Tensor,
88+
p: torch.Tensor,
89+
batch_size: int,
90+
device_id: int | None = None
8891
) -> torch.Tensor:
8992
"""
9093
Sample errors from a detector error model (DEM) via cuST BitMatrixSampler.
@@ -115,9 +118,7 @@ def dem_sampling(
115118
if device_id is None:
116119
if H.is_cuda:
117120
device_index = H.device.index
118-
device_id = int(
119-
torch.cuda.current_device() if device_index is None else device_index
120-
)
121+
device_id = int(torch.cuda.current_device() if device_index is None else device_index)
121122
else:
122123
device_id = 0
123124

@@ -130,9 +131,7 @@ def dem_sampling(
130131
_cached_device_id = None
131132

132133
need_new = (
133-
_cached_sampler is None
134-
or batch_size > _cached_max_shots
135-
or _cached_device_id != device_id
134+
_cached_sampler is None or batch_size > _cached_max_shots or _cached_device_id != device_id
136135
)
137136

138137
if need_new:

code/qec/surface_code/memory_circuit_torch.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,9 +263,7 @@ def generate_batch(
263263
device_id = None
264264
if self.device.type == "cuda":
265265
device_index = self.device.index
266-
device_id = int(
267-
torch.cuda.current_device() if device_index is None else device_index
268-
)
266+
device_id = int(torch.cuda.current_device() if device_index is None else device_index)
269267
frames_xz = dem_sampling(
270268
self.H,
271269
self.p,

0 commit comments

Comments
 (0)