Skip to content

Commit 8d8a6c5

Browse files
authored
Enable pulling through Clifford operations, also add an option of only applying dd to single qubit gate moments (#6675)
1 parent 1640116 commit 8d8a6c5

File tree

2 files changed

+733
-57
lines changed

2 files changed

+733
-57
lines changed

cirq-core/cirq/transformers/dynamical_decoupling.py

Lines changed: 212 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,54 +15,57 @@
1515
"""Transformer pass that adds dynamical decoupling operations to a circuit."""
1616

1717
from functools import reduce
18-
from typing import Dict, Optional, Sequence, Tuple, Union
18+
from typing import Dict, Optional, Tuple, Union
19+
from itertools import cycle
1920

2021
from cirq.transformers import transformer_api
22+
from cirq.transformers.analytical_decompositions import single_qubit_decompositions
23+
from cirq.transformers.analytical_decompositions import unitary_to_pauli_string
2124
import cirq
2225
import numpy as np
2326

2427

25-
def _repeat_sequence(
26-
base_sequence: Sequence['cirq.Gate'], num_idle_moments: int
27-
) -> Sequence['cirq.Gate']:
28-
"""Returns the longest possible dynamical decoupling sequence."""
29-
repeat_times = num_idle_moments // len(base_sequence)
30-
return list(base_sequence) * repeat_times
31-
32-
33-
def _get_dd_sequence_from_schema_name(schema: str) -> Sequence['cirq.Gate']:
28+
def _get_dd_sequence_from_schema_name(schema: str) -> Tuple['cirq.Gate', ...]:
3429
"""Gets dynamical decoupling sequence from a schema name."""
35-
dd_sequence: Sequence['cirq.Gate']
3630
match schema:
31+
case 'DEFAULT':
32+
return (cirq.X, cirq.Y, cirq.X, cirq.Y)
3733
case 'XX_PAIR':
38-
dd_sequence = (cirq.X, cirq.X)
34+
return (cirq.X, cirq.X)
3935
case 'X_XINV':
40-
dd_sequence = (cirq.X, cirq.X**-1)
36+
return (cirq.X, cirq.X**-1)
4137
case 'YY_PAIR':
42-
dd_sequence = (cirq.Y, cirq.Y)
38+
return (cirq.Y, cirq.Y)
4339
case 'Y_YINV':
44-
dd_sequence = (cirq.Y, cirq.Y**-1)
40+
return (cirq.Y, cirq.Y**-1)
4541
case _:
4642
raise ValueError('Invalid schema name.')
47-
return dd_sequence
4843

4944

50-
def _validate_dd_sequence(dd_sequence: Sequence['cirq.Gate']) -> None:
45+
def _pauli_up_to_global_phase(gate: 'cirq.Gate') -> Union['cirq.Pauli', None]:
46+
for pauli_gate in [cirq.X, cirq.Y, cirq.Z]:
47+
if cirq.equal_up_to_global_phase(gate, pauli_gate):
48+
return pauli_gate
49+
return None
50+
51+
52+
def _validate_dd_sequence(dd_sequence: Tuple['cirq.Gate', ...]) -> None:
5153
"""Validates a given dynamical decoupling sequence.
5254
5355
Args:
5456
dd_sequence: Input dynamical sequence to be validated.
5557
56-
Returns:
57-
A tuple containing:
58-
- is_valid (bool): True if the dd sequence is valid, False otherwise.
59-
- error_message (str): An error message if the dd sequence is invalid, else None.
60-
6158
Raises:
6259
ValueError: If dd_sequence is not valid.
6360
"""
6461
if len(dd_sequence) < 2:
6562
raise ValueError('Invalid dynamical decoupling sequence. Expect more than one gates.')
63+
for gate in dd_sequence:
64+
if _pauli_up_to_global_phase(gate) is None:
65+
raise ValueError(
66+
'Dynamical decoupling sequence should only contain gates that are essentially'
67+
' Pauli gates.'
68+
)
6669
matrices = [cirq.unitary(gate) for gate in dd_sequence]
6770
product = reduce(np.matmul, matrices)
6871

@@ -73,50 +76,218 @@ def _validate_dd_sequence(dd_sequence: Sequence['cirq.Gate']) -> None:
7376
)
7477

7578

76-
def _parse_dd_sequence(schema: Union[str, Sequence['cirq.Gate']]) -> Sequence['cirq.Gate']:
79+
def _parse_dd_sequence(schema: Union[str, Tuple['cirq.Gate', ...]]) -> Tuple['cirq.Gate', ...]:
7780
"""Parses and returns dynamical decoupling sequence from schema."""
7881
if isinstance(schema, str):
79-
dd_sequence = _get_dd_sequence_from_schema_name(schema)
82+
return _get_dd_sequence_from_schema_name(schema)
8083
else:
8184
_validate_dd_sequence(schema)
82-
dd_sequence = schema
83-
return dd_sequence
85+
return schema
86+
87+
88+
def _is_single_qubit_operation(operation: 'cirq.Operation') -> bool:
89+
if len(operation.qubits) != 1:
90+
return False
91+
return True
92+
93+
94+
def _is_single_qubit_gate_moment(moment: 'cirq.Moment') -> bool:
95+
for operation in moment:
96+
if not _is_single_qubit_operation(operation):
97+
return False
98+
return True
99+
100+
101+
def _is_clifford_moment(moment: 'cirq.Moment') -> bool:
102+
for op in moment.operations:
103+
if op.gate is not None and isinstance(op.gate, cirq.MeasurementGate):
104+
return False
105+
if not cirq.has_stabilizer_effect(op):
106+
return False
107+
return True
108+
109+
110+
def _get_clifford_pieces(circuit: 'cirq.AbstractCircuit') -> list[Tuple[int, int]]:
111+
clifford_pieces: list[Tuple[int, int]] = []
112+
left = 0
113+
for moment_id, moment in enumerate(circuit):
114+
if not _is_clifford_moment(moment):
115+
clifford_pieces.append((left, moment_id))
116+
left = moment_id + 1
117+
if left < len(circuit):
118+
clifford_pieces.append((left, len(circuit)))
119+
return clifford_pieces
120+
121+
122+
def _is_insertable_moment(moment: 'cirq.Moment', single_qubit_gate_moments_only: bool) -> bool:
123+
return _is_single_qubit_gate_moment(moment) or not single_qubit_gate_moments_only
124+
125+
126+
def _calc_pulled_through(
127+
moment: 'cirq.Moment', input_pauli_ops: 'cirq.PauliString'
128+
) -> 'cirq.PauliString':
129+
"""Calculates the pulled_through after pulling through moment with the input.
130+
131+
We assume that the moment is Clifford here. Then, pulling through is essentially
132+
decomposing a matrix into Pauli operations on each qubit.
133+
"""
134+
pulled_through: 'cirq.PauliString' = cirq.PauliString()
135+
for affected_q, combined_op_in_pauli in input_pauli_ops.items():
136+
op_at_moment = moment.operation_at(affected_q)
137+
if op_at_moment is None:
138+
pulled_through *= combined_op_in_pauli.on(affected_q)
139+
continue
140+
prev_circuit = cirq.Circuit(cirq.Moment(op_at_moment))
141+
new_circuit = cirq.Circuit(
142+
cirq.Moment(combined_op_in_pauli.on(affected_q)), cirq.Moment(op_at_moment)
143+
)
144+
qubit_order = op_at_moment.qubits
145+
pulled_through_pauli_ops = unitary_to_pauli_string(
146+
prev_circuit.unitary(qubit_order=qubit_order)
147+
@ new_circuit.unitary(qubit_order=qubit_order).conj().T
148+
)
149+
if pulled_through_pauli_ops is not None:
150+
for qid, gate in enumerate(pulled_through_pauli_ops):
151+
pulled_through *= gate.on(qubit_order[qid])
152+
return pulled_through
153+
154+
155+
def _merge_pulled_through(
156+
mutable_circuit: 'cirq.Circuit',
157+
pulled_through: 'cirq.PauliString',
158+
clifford_piece_range: Tuple[int, int],
159+
single_qubit_gate_moments_only: bool,
160+
) -> 'cirq.PauliString':
161+
"""Merges pulled through Pauli gates into the last single-qubit gate operation or the insert it
162+
into the first idle moment if idle moments exist.
163+
Args:
164+
mutable_circuit: Mutable circuit to transform.
165+
pulled_through: Pauli gates to be merged.
166+
clifford_piece_range: Specifies the [l, r) moments within which pulled-through gate merging
167+
is to be performed.
168+
single_qubit_gate_moments_only: If set True, dynamical decoupling operation will only be
169+
added in single-qubit gate moments.
170+
171+
Returns:
172+
The remaining pulled through operations after merging.
173+
"""
174+
insert_intos: list[Tuple[int, 'cirq.Operation']] = []
175+
batch_replaces: list[Tuple[int, 'cirq.Operation', 'cirq.Operation']] = []
176+
remaining_pulled_through = pulled_through
177+
for affected_q, combined_op_in_pauli in pulled_through.items():
178+
moment_id = mutable_circuit.prev_moment_operating_on([affected_q], clifford_piece_range[1])
179+
if moment_id is not None:
180+
op = mutable_circuit.operation_at(affected_q, moment_id)
181+
# Try to merge op into an existing single-qubit gate operation.
182+
if op is not None and _is_single_qubit_operation(op):
183+
updated_gate_mat = cirq.unitary(combined_op_in_pauli) @ cirq.unitary(op)
184+
updated_gate: Optional['cirq.Gate'] = (
185+
single_qubit_decompositions.single_qubit_matrix_to_phxz(updated_gate_mat)
186+
)
187+
if updated_gate is None:
188+
# updated_gate is close to Identity.
189+
updated_gate = cirq.I
190+
batch_replaces.append((moment_id, op, updated_gate.on(affected_q)))
191+
remaining_pulled_through *= combined_op_in_pauli.on(affected_q)
192+
continue
193+
# Insert into the first empty moment for the qubit if such moment exists.
194+
while moment_id < clifford_piece_range[1]:
195+
if affected_q not in mutable_circuit.moments[
196+
moment_id
197+
].qubits and _is_insertable_moment(
198+
mutable_circuit.moments[moment_id], single_qubit_gate_moments_only
199+
):
200+
insert_intos.append((moment_id, combined_op_in_pauli.on(affected_q)))
201+
remaining_pulled_through *= combined_op_in_pauli.on(affected_q)
202+
break
203+
moment_id += 1
204+
mutable_circuit.batch_insert_into(insert_intos)
205+
mutable_circuit.batch_replace(batch_replaces)
206+
return remaining_pulled_through
84207

85208

86209
@transformer_api.transformer
87210
def add_dynamical_decoupling(
88211
circuit: 'cirq.AbstractCircuit',
89212
*,
90213
context: Optional['cirq.TransformerContext'] = None,
91-
schema: Union[str, Sequence['cirq.Gate']] = 'X_XINV',
214+
schema: Union[str, Tuple['cirq.Gate', ...]] = 'DEFAULT',
215+
single_qubit_gate_moments_only: bool = True,
92216
) -> 'cirq.Circuit':
93-
"""Adds dynamical decoupling gate operations to idle moments of a given circuit.
94-
This transformer preserves the moment structure of the circuit.
217+
"""Adds dynamical decoupling gate operations to a given circuit.
218+
This transformer might add a new moment after each piece of Clifford moments, so the original
219+
moment structure could change.
95220
96221
Args:
97222
circuit: Input circuit to transform.
98223
context: `cirq.TransformerContext` storing common configurable options for transformers.
99224
schema: Dynamical decoupling schema name or a dynamical decoupling sequence.
100225
If a schema is specified, provided dynamical decouping sequence will be used.
101226
Otherwise, customized dynamical decoupling sequence will be applied.
227+
single_qubit_gate_moments_only: If set True, dynamical decoupling operation will only be
228+
added in single-qubit gate moments.
102229
103230
Returns:
104231
A copy of the input circuit with dynamical decoupling operations.
105232
"""
106-
last_busy_moment_by_qubits: Dict['cirq.Qid', int] = {q: 0 for q in circuit.all_qubits()}
107-
insert_into: list[Tuple[int, 'cirq.OP_TREE']] = []
233+
base_dd_sequence: Tuple['cirq.Gate', ...] = _parse_dd_sequence(schema)
234+
mutable_circuit = circuit.unfreeze(copy=True)
108235

109-
base_dd_sequence = _parse_dd_sequence(schema)
236+
pauli_map: Dict['cirq.Gate', 'cirq.Pauli'] = {}
237+
for gate in base_dd_sequence:
238+
pauli_gate = _pauli_up_to_global_phase(gate)
239+
if pauli_gate is not None:
240+
pauli_map[gate] = pauli_gate
110241

242+
busy_moment_range_by_qubit: Dict['cirq.Qid', list[int]] = {
243+
q: [len(circuit), -1] for q in circuit.all_qubits()
244+
}
111245
for moment_id, moment in enumerate(circuit):
112246
for q in moment.qubits:
113-
insert_gates = _repeat_sequence(
114-
base_dd_sequence, num_idle_moments=moment_id - last_busy_moment_by_qubits[q] - 1
115-
)
116-
for idx, gate in enumerate(insert_gates):
117-
insert_into.append((last_busy_moment_by_qubits[q] + idx + 1, gate.on(q)))
118-
last_busy_moment_by_qubits[q] = moment_id
247+
busy_moment_range_by_qubit[q][0] = min(busy_moment_range_by_qubit[q][0], moment_id)
248+
busy_moment_range_by_qubit[q][1] = max(busy_moment_range_by_qubit[q][1], moment_id)
249+
clifford_pieces = _get_clifford_pieces(circuit)
250+
251+
insert_intos: list[Tuple[int, 'cirq.Operation']] = []
252+
insert_moments: list[Tuple[int, 'cirq.Moment']] = []
253+
for l, r in clifford_pieces: # [l, r)
254+
# A PauliString stores the result of 'pulling' Pauli gates past each operations
255+
# right before the current moment.
256+
pulled_through: 'cirq.PauliString' = cirq.PauliString()
257+
iter_by_qubits = {q: cycle(base_dd_sequence) for q in circuit.all_qubits()}
258+
259+
# Iterate over the Clifford piece.
260+
for moment_id in range(l, r):
261+
moment = circuit.moments[moment_id]
262+
263+
# Insert
264+
if _is_insertable_moment(moment, single_qubit_gate_moments_only):
265+
for q in circuit.all_qubits() - moment.qubits:
266+
if (
267+
busy_moment_range_by_qubit[q][0]
268+
< moment_id
269+
< busy_moment_range_by_qubit[q][1]
270+
):
271+
insert_gate = next(iter_by_qubits[q])
272+
insert_intos.append((moment_id, insert_gate.on(q)))
273+
pulled_through *= pauli_map[insert_gate].on(q)
274+
275+
# Pull through
276+
pulled_through = _calc_pulled_through(moment, pulled_through)
277+
278+
mutable_circuit.batch_insert_into(insert_intos)
279+
insert_intos.clear()
280+
281+
pulled_through = _merge_pulled_through(
282+
mutable_circuit, pulled_through, (l, r), single_qubit_gate_moments_only
283+
)
284+
285+
# Insert a new moment if there are remaining pulled through operations.
286+
new_moment_ops = []
287+
for affected_q, combined_op_in_pauli in pulled_through.items():
288+
new_moment_ops.append(combined_op_in_pauli.on(affected_q))
289+
if len(new_moment_ops) != 0:
290+
insert_moments.append((r, cirq.Moment(new_moment_ops)))
119291

120-
updated_circuit = circuit.unfreeze(copy=True)
121-
updated_circuit.batch_insert_into(insert_into)
122-
return updated_circuit
292+
mutable_circuit.batch_insert(insert_moments)
293+
return mutable_circuit

0 commit comments

Comments
 (0)