Skip to content

Commit

Permalink
Improve the effect of the anc_detectors argument (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcSerraPeralta authored Oct 16, 2024
1 parent 39565cc commit 4e7242d
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 39 deletions.
7 changes: 5 additions & 2 deletions surface_sim/circuit_blocks/rot_surface_code_xzzx_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,12 @@ def qec_round_with_log_meas(

stab_type = "x_type" if rot_basis else "z_type"
stabs = layout.get_qubits(role="anc", stab_type=stab_type)
stabs = [s for s in stabs if s in anc_detectors]
detectors_stim = detectors.build_from_data(
model.meas_target, layout.adjacency_matrix(), anc_reset=True, anc_qubits=stabs
model.meas_target,
layout.adjacency_matrix(),
anc_reset=True,
reconstructable_stabs=stabs,
anc_qubits=anc_detectors,
)
circuit += detectors_stim

Expand Down
14 changes: 10 additions & 4 deletions surface_sim/circuit_blocks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,12 @@ def log_meas(
# detectors and logical observables
stab_type = "x_type" if rot_basis else "z_type"
stabs = layout.get_qubits(role="anc", stab_type=stab_type)
stabs = [s for s in stabs if s in anc_detectors]
detectors_stim = detectors.build_from_data(
model.meas_target, layout.adjacency_matrix(), anc_reset, anc_qubits=stabs
model.meas_target,
layout.adjacency_matrix(),
anc_reset,
reconstructable_stabs=stabs,
anc_qubits=anc_detectors,
)
circuit += detectors_stim

Expand Down Expand Up @@ -313,9 +316,12 @@ def log_meas_xzzx(
# detectors and logical observables
stab_type = "x_type" if rot_basis else "z_type"
stabs = layout.get_qubits(role="anc", stab_type=stab_type)
stabs = [s for s in stabs if s in anc_detectors]
detectors_stim = detectors.build_from_data(
model.meas_target, layout.adjacency_matrix(), anc_reset, anc_qubits=stabs
model.meas_target,
layout.adjacency_matrix(),
anc_reset,
reconstructable_stabs=stabs,
anc_qubits=anc_detectors,
)
circuit += detectors_stim

Expand Down
72 changes: 58 additions & 14 deletions surface_sim/detectors/detectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable
from collections.abc import Callable, Iterable
from copy import deepcopy

import numpy as np
Expand All @@ -11,7 +11,7 @@


class Detectors:
def __init__(self, anc_qubits: list[str], frame: str) -> None:
def __init__(self, anc_qubits: Iterable[str], frame: str) -> None:
"""Initalises the ``Detectors`` class.
Parameters
Expand All @@ -34,6 +34,13 @@ def __init__(self, anc_qubits: list[str], frame: str) -> None:
Detector frame ``'r-1'`` build the detectors in the basis given by the
stabilizer generators of the previous-last-measured QEC round.
"""
if not isinstance(anc_qubits, Iterable):
raise TypeError(
f"'anc_qubits' must be iterable, but {type(anc_qubits)} was given."
)
if not isinstance(frame, str):
raise TypeError(f"'frame' must be a str, but {type(frame)} was given.")

self.anc_qubits = anc_qubits
self.frame = frame

Expand Down Expand Up @@ -121,7 +128,7 @@ def build_from_anc(
self,
get_rec: Callable,
anc_reset: bool,
anc_qubits: list[str] | None = None,
anc_qubits: Iterable[str] | None = None,
) -> stim.Circuit:
"""Returns the stim circuit with the corresponding detectors
given that the ancilla qubits have been measured.
Expand All @@ -143,6 +150,15 @@ def build_from_anc(
detectors_stim
Detectors defined in a ``stim`` circuit.
"""
if not (isinstance(anc_qubits, Iterable) or (anc_qubits is None)):
raise TypeError(
f"'anc_qubits' must be iterable or None, but {type(anc_qubits)} was given."
)
if not isinstance(get_rec, Callable):
raise TypeError(
f"'get_rec' must be callable, but {type(get_rec)} was given."
)

if self.frame == "1":
basis = self.init_gen
elif self.frame == "r":
Expand All @@ -154,6 +170,9 @@ def build_from_anc(
f"'frame' must be '1', 'r-1', or 'r', but {self.frame} was given."
)

if anc_qubits is None:
anc_qubits = self.curr_gen.stab_gen.values.tolist()

self.num_rounds += 1

detectors = _get_ancilla_meas_for_detectors(
Expand All @@ -164,13 +183,15 @@ def build_from_anc(
anc_reset_curr=anc_reset,
anc_reset_prev=anc_reset,
)
if anc_qubits is not None:
detectors = {anc: d for anc, d in detectors.items() if anc in anc_qubits}

# build the stim circuit
detectors_stim = stim.Circuit()
for targets in detectors.values():
detectors_rec = [get_rec(*t) for t in targets]
for anc, targets in detectors.items():
if anc in anc_qubits:
detectors_rec = [get_rec(*t) for t in targets]
else:
# create the detector but make it be always 0
detectors_rec = []
detectors_stim.append("DETECTOR", detectors_rec, [])

# update generators
Expand All @@ -183,7 +204,8 @@ def build_from_data(
get_rec: Callable,
adjacency_matrix: xr.DataArray,
anc_reset: bool,
anc_qubits: list[str] | None = None,
reconstructable_stabs: Iterable[str],
anc_qubits: Iterable[str] | None = None,
) -> stim.Circuit:
"""Returns the stim circuit with the corresponding detectors
given that the data qubits have been measured.
Expand All @@ -200,6 +222,8 @@ def build_from_data(
See ``qec_util.Layout.adjacency_matrix`` for more information.
anc_reset
Flag for if the ancillas are being reset in every QEC cycle.
reconstructable_stabs
Stabilizers that can be reconstructed from the data qubit outcomes.
anc_qubits
List of the ancilla qubits for which to build the detectors.
By default, builds all the detectors.
Expand All @@ -209,6 +233,20 @@ def build_from_data(
detectors_stim
Detectors defined in a ``stim`` circuit.
"""
if not isinstance(reconstructable_stabs, Iterable):
raise TypeError(
"'reconstructable_stabs' must be iterable, "
f"but {type(reconstructable_stabs)} was given."
)
if not (isinstance(anc_qubits, Iterable) or (anc_qubits is None)):
raise TypeError(
f"'anc_qubits' must be iterable or None, but {type(anc_qubits)} was given."
)
if not isinstance(get_rec, Callable):
raise TypeError(
f"'get_rec' must be callable, but {type(get_rec)} was given."
)

if self.frame == "1":
basis = self.init_gen
elif self.frame == "r":
Expand All @@ -220,6 +258,9 @@ def build_from_data(
f"'frame' must be '1', 'r-1', or 'r', but {self.frame} was given."
)

if anc_qubits is None:
anc_qubits = self.curr_gen.stab_gen.values.tolist()

self.num_rounds += 1

anc_detectors = _get_ancilla_meas_for_detectors(
Expand All @@ -230,10 +271,9 @@ def build_from_data(
anc_reset_curr=True,
anc_reset_prev=anc_reset,
)
if anc_qubits is not None:
anc_detectors = {
anc: d for anc, d in anc_detectors.items() if anc in anc_qubits
}
anc_detectors = {
anc: d for anc, d in anc_detectors.items() if anc in reconstructable_stabs
}

# udpate the (anc, -1) to a the corresponding set of (data, -1)
detectors = {}
Expand All @@ -257,8 +297,12 @@ def build_from_data(

# build the stim circuit
detectors_stim = stim.Circuit()
for targets in detectors.values():
detectors_rec = [get_rec(*t) for t in targets]
for anc, targets in detectors.items():
if anc in anc_qubits:
detectors_rec = [get_rec(*t) for t in targets]
else:
# create the detector but make it be always 0
detectors_rec = []
detectors_stim.append("DETECTOR", detectors_rec, [])

# update generators
Expand Down
49 changes: 37 additions & 12 deletions tests/detectors/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def test_detectors_update():
data=np.identity(2),
coords=dict(from_qubit=["X1", "Z1"], to_qubit=["D1", "D2"]),
),
reconstructable_stabs=anc_qubits,
)
assert (detectors.curr_gen == new_gen).all()
assert (detectors.prev_gen == new_gen).all()
Expand Down Expand Up @@ -270,7 +271,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="1")
detectors.num_rounds = 0
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -280,7 +283,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="1")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -290,7 +295,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="1")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -300,7 +307,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="1")
detectors.num_rounds = 2
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand Down Expand Up @@ -333,7 +342,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r")
detectors.num_rounds = 0
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -343,7 +354,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -353,7 +366,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -363,7 +378,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r")
detectors.num_rounds = 2
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand Down Expand Up @@ -396,7 +413,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r-1")
detectors.num_rounds = 0
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -406,7 +425,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r-1")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=True)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=True, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -416,7 +437,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r-1")
detectors.num_rounds = 1
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand All @@ -426,7 +449,9 @@ def meas_rec(q, t):
detectors = Detectors(anc_qubits=anc_qubits, frame="r-1")
detectors.num_rounds = 2
detectors.update(unitary_mat)
detectors_stim = detectors.build_from_data(meas_rec, adj_matrix, anc_reset=False)
detectors_stim = detectors.build_from_data(
meas_rec, adj_matrix, anc_reset=False, reconstructable_stabs=anc_qubits
)
detector_rec = [
sorted([t.value for t in instr.targets_copy()]) for instr in detectors_stim
]
Expand Down
22 changes: 20 additions & 2 deletions tests/experiments/test_rot_surface_code_css.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,16 @@ def test_memory_experiment_anc_detectors():
rot_basis=True,
)

assert circuit.num_detectors == 10 + 1
num_anc = len(layout.get_qubits(role="anc"))
num_anc_x = len(layout.get_qubits(role="anc", stab_type="x_type"))
assert circuit.num_detectors == 10 * num_anc + num_anc_x

non_zero_dets = []
for instr in circuit.flattened():
if instr.name == "DETECTOR" and len(instr.targets_copy()) != 0:
non_zero_dets.append(instr)

assert len(non_zero_dets) == 10 + 1

return

Expand All @@ -100,6 +109,15 @@ def test_repeated_s_experiment_anc_detectors():
rot_basis=True,
)

assert circuit.num_detectors == 1 + 4 * 2 + 1
num_anc = len(layout.get_qubits(role="anc"))
num_anc_x = len(layout.get_qubits(role="anc", stab_type="x_type"))
assert circuit.num_detectors == (1 + 4 * 2) * num_anc + num_anc_x

non_zero_dets = []
for instr in circuit.flattened():
if instr.name == "DETECTOR" and len(instr.targets_copy()) != 0:
non_zero_dets.append(instr)

assert len(non_zero_dets) == 1 + 4 * 2 + 1

return
Loading

0 comments on commit 4e7242d

Please sign in to comment.