Skip to content

Commit b38721f

Browse files
committed
Eliminate need for using NumPy 2 printoptions()
Further study of the cases and the methods available in NumPy led to a solution that doesn't require using `printoptions(legacy="1.25")` to prevent data types from showing up in things like circuit diagrams.
1 parent 5a67797 commit b38721f

File tree

3 files changed

+32
-25
lines changed

3 files changed

+32
-25
lines changed

cirq-core/cirq/_compat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,10 @@ def _print(self, expr, **kwargs):
192192
if hasattr(value, "__qualname__"):
193193
return f"{value.__module__}.{value.__qualname__}"
194194

195-
with np.printoptions(legacy='1.25'):
196-
return repr(value)
195+
if isinstance(value, np.number):
196+
return repr(value.item())
197+
198+
return repr(value)
197199

198200

199201
def dataclass_repr(value: Any, namespace: str = 'cirq') -> str:

cirq-core/cirq/ops/fsim_gate.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import cmath
2525
import math
26-
from typing import AbstractSet, Any, Dict, Iterator, Optional, Tuple
26+
from typing import AbstractSet, Any, Dict, Iterator, Optional, Tuple, Union
2727

2828
import numpy as np
2929
import sympy
@@ -52,6 +52,12 @@ def _half_pi_mod_pi(param: 'cirq.TParamVal') -> bool:
5252
return param in (-np.pi / 2, np.pi / 2, -sympy.pi / 2, sympy.pi / 2)
5353

5454

55+
def _plainvalue(value: Union[int, float, complex, np.number]) -> Union[int, float, complex]:
56+
"""Returns a plain Python number if the given value is a NumPy number.
57+
Used to avoid a change in repr behavior introduced in NumPy 2."""
58+
return value.item() if isinstance(value, np.number) else value
59+
60+
5561
@value.value_equality(approximate=True)
5662
class FSimGate(gate_features.InterchangeableQubitsGate, raw_types.Gate):
5763
r"""Fermionic simulation gate.
@@ -196,10 +202,9 @@ def _decompose_(self, qubits) -> Iterator['cirq.OP_TREE']:
196202
yield cirq.CZ(a, b) ** (-self.phi / np.pi)
197203

198204
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> Tuple[str, ...]:
199-
with np.printoptions(legacy='1.25'):
200-
t = args.format_radians(self.theta)
201-
p = args.format_radians(self.phi)
202-
return f'FSim({t}, {p})', f'FSim({t}, {p})'
205+
t = args.format_radians(_plainvalue(self.theta))
206+
p = args.format_radians(_plainvalue(self.phi))
207+
return f'FSim({t}, {p})', f'FSim({t}, {p})'
203208

204209
def __pow__(self, power) -> 'FSimGate':
205210
return FSimGate(cirq.mul(self.theta, power), cirq.mul(self.phi, power))
@@ -477,16 +482,15 @@ def to_exponent(angle_rads: 'cirq.TParamVal') -> 'cirq.TParamVal':
477482
yield cirq.Z(q1) ** to_exponent(after[1])
478483

479484
def _circuit_diagram_info_(self, args: 'cirq.CircuitDiagramInfoArgs') -> Tuple[str, ...]:
480-
with np.printoptions(legacy='1.25'):
481-
theta = args.format_radians(self.theta)
482-
zeta = args.format_radians(self.zeta)
483-
chi = args.format_radians(self.chi)
484-
gamma = args.format_radians(self.gamma)
485-
phi = args.format_radians(self.phi)
486-
return (
487-
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
488-
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
489-
)
485+
theta = args.format_radians(_plainvalue(self.theta))
486+
zeta = args.format_radians(_plainvalue(self.zeta))
487+
chi = args.format_radians(_plainvalue(self.chi))
488+
gamma = args.format_radians(_plainvalue(self.gamma))
489+
phi = args.format_radians(_plainvalue(self.phi))
490+
return (
491+
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
492+
f'PhFSim({theta}, {zeta}, {chi}, {gamma}, {phi})',
493+
)
490494

491495
def __repr__(self) -> str:
492496
theta = proper_repr(self.theta)

cirq-core/cirq/protocols/circuit_diagram_info_protocol.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -243,14 +243,15 @@ def __repr__(self) -> str:
243243
)
244244

245245
def format_real(self, val: Union[sympy.Basic, int, float]) -> str:
246-
with np.printoptions(legacy='1.25'):
247-
if isinstance(val, sympy.Basic):
248-
return str(val)
249-
if val == int(val):
250-
return str(int(val))
251-
if self.precision is None:
252-
return str(val)
253-
return f'{float(val):.{self.precision}}'
246+
if isinstance(val, sympy.Basic):
247+
return str(val)
248+
if isinstance(val, np.number):
249+
val = val.item()
250+
if val == int(val):
251+
return str(int(val))
252+
if self.precision is None:
253+
return str(val)
254+
return f'{float(val):.{self.precision}}'
254255

255256
def format_complex(self, val: Union[sympy.Basic, int, float, 'cirq.TParamValComplex']) -> str:
256257
if isinstance(val, sympy.Basic):

0 commit comments

Comments
 (0)