Skip to content

Commit 39b1537

Browse files
Add plotting method and calculate the variance of estimated fidelity
1 parent 6f71abf commit 39b1537

File tree

2 files changed

+51
-3
lines changed

2 files changed

+51
-3
lines changed

cirq-core/cirq/experiments/xeb_fitting.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ def benchmark_2q_xeb_fidelities(
9595
df['e_u'] = np.sum(pure_probs**2, axis=1)
9696
df['u_u'] = np.sum(pure_probs, axis=1) / D
9797
df['m_u'] = np.sum(pure_probs * sampled_probs, axis=1)
98+
# Var[m_u] = Var[sum p(x) * p_sampled(x)]
99+
# = sum p(x)^2 Var[p_sampled(x)]
100+
# = sum p(x)^2 p(x) (1 - p(x))
101+
# = sum p(x)^3 (1 - p(x))
102+
df['var_m_u'] = np.sum(pure_probs**3 * (1 - pure_probs), axis=1)
98103
df['y'] = df['m_u'] - df['u_u']
99104
df['x'] = df['e_u'] - df['u_u']
100105
df['numerator'] = df['x'] * df['y']
@@ -103,7 +108,11 @@ def benchmark_2q_xeb_fidelities(
103108
def per_cycle_depth(df):
104109
"""This function is applied per cycle_depth in the following groupby aggregation."""
105110
fid_lsq = df['numerator'].sum() / df['denominator'].sum()
106-
ret = {'fidelity': fid_lsq}
111+
# Note: both df['denominator'] an df['x'] are constants.
112+
# Var[f] = Var[df['numerator']] / (sum df['denominator'])^2
113+
# = sum (df['x']^2 * df['var_m_u']) / (sum df['denominator'])^2
114+
var_fid = (df['var_m_u'] * df['x'] ** 2).sum() / df['denominator'].sum() ** 2
115+
ret = {'fidelity': fid_lsq, 'fidelity_variance': var_fid}
107116

108117
def _try_keep(k):
109118
"""If all the values for a key `k` are the same in this group, we can keep it."""
@@ -678,7 +687,7 @@ def _per_pair(f1):
678687
'cycle_depths': f1['cycle_depth'].values,
679688
'fidelities': f1['fidelity'].values,
680689
'a_std': a_std,
681-
'layer_fid_std': layer_fid_std,
690+
'layer_fid_std': np.sqrt(layer_fid_std**2 + f1['fidelity_variance'].values),
682691
}
683692
return pd.Series(record)
684693

cirq-core/cirq/experiments/z_phase_calibration.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import cirq
2626
import pandas as pd
2727
import multiprocessing
28+
import matplotlib.pyplot as plt
2829

2930

3031
def z_phase_calibration_workflow(
@@ -109,10 +110,12 @@ def z_phase_calibration_workflow(
109110
pool=pool,
110111
)
111112

112-
return result, xeb_fitting.before_and_after_characterization(
113+
before_after = xeb_fitting.before_and_after_characterization(
113114
fids_df_0, characterization_result=result
114115
)
115116

117+
return result, before_after
118+
116119

117120
def calibrate_z_phases(
118121
sampler: 'cirq.Sampler',
@@ -191,3 +194,39 @@ def calibrate_z_phases(
191194
params['gamma'] = params.get('gamma', options.gamma_default or 0)
192195
gates[pair] = ops.PhasedFSimGate(**params)
193196
return gates
197+
198+
199+
def plot_z_phase_calibration_result(
200+
before_after_df: 'pd.DataFrame',
201+
axes: np.ndarray[Sequence[Sequence['plt.Axes']], np.dtype[np.object_]],
202+
*,
203+
with_error_bars: bool = False,
204+
) -> None:
205+
"""A helper method to plot the result of running z-phase calibration.
206+
207+
Args:
208+
before_after_df: The second return object of running `z_phase_calibration_workflow`.
209+
axes: And ndarray of the axes to plot on.
210+
The number of axes is expected to be >= number of qubit pairs.
211+
with_error_bars: Whether to add error bars or not.
212+
The width of the bar is an upper bound on standard variation of the estimated fidelity.
213+
"""
214+
for pair, ax in zip(before_after_df.index, axes.flatten()):
215+
row = before_after_df.loc[[pair]].iloc[0]
216+
ax.errorbar(
217+
row.cycle_depths_0,
218+
row.fidelities_0,
219+
yerr=row.layer_fid_std_0 * with_error_bars,
220+
label='original',
221+
)
222+
ax.errorbar(
223+
row.cycle_depths_0,
224+
row.fidelities_c,
225+
yerr=row.layer_fid_std_c * with_error_bars,
226+
label='calibrated',
227+
)
228+
ax.axhline(1, linestyle='--')
229+
ax.set_xlabel('cycle depth')
230+
ax.set_ylabel('fidelity estimate')
231+
ax.set_title('-'.join(str(q) for q in pair))
232+
ax.legend()

0 commit comments

Comments
 (0)