Skip to content

Commit b241ec2

Browse files
Merge branch 'main' into z_cal2
2 parents 5d668e2 + 2c914ce commit b241ec2

23 files changed

+1212
-335
lines changed

cirq-core/cirq/work/sampler.py

Lines changed: 5 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,7 @@
1414
"""Abstract base class for things sampling quantum circuits."""
1515

1616
import collections
17-
from itertools import islice
18-
from typing import (
19-
Dict,
20-
FrozenSet,
21-
Iterator,
22-
List,
23-
Optional,
24-
Sequence,
25-
Tuple,
26-
TypeVar,
27-
TYPE_CHECKING,
28-
Union,
29-
)
17+
from typing import Dict, FrozenSet, List, Optional, Sequence, Tuple, TypeVar, TYPE_CHECKING, Union
3018

3119
import duet
3220
import pandas as pd
@@ -49,14 +37,6 @@
4937
class Sampler(metaclass=value.ABCMetaImplementAnyOneOf):
5038
"""Something capable of sampling quantum circuits. Simulator or hardware."""
5139

52-
# Users have a rate limit of 1000 QPM for read/write requests to
53-
# the Quantum Engine. The sampler will poll from the DB every 1s
54-
# for inflight requests for results. Empirically, for circuits
55-
# sent in run_batch, sending circuits in CHUNK_SIZE=5 for large
56-
# number of circuits (> 200) with large depths (100 layers)
57-
# does not encounter quota exceeded issues for non-streaming cases.
58-
CHUNK_SIZE: int = 5
59-
6040
def run(
6141
self,
6242
program: 'cirq.AbstractCircuit',
@@ -311,32 +291,16 @@ async def run_batch_async(
311291
programs: Sequence['cirq.AbstractCircuit'],
312292
params_list: Optional[Sequence['cirq.Sweepable']] = None,
313293
repetitions: Union[int, Sequence[int]] = 1,
294+
limiter: duet.Limiter = duet.Limiter(10),
314295
) -> Sequence[Sequence['cirq.Result']]:
315296
"""Runs the supplied circuits asynchronously.
316297
317298
See docs for `cirq.Sampler.run_batch`.
318299
"""
319300
params_list, repetitions = self._normalize_batch_args(programs, params_list, repetitions)
320-
if len(programs) <= self.CHUNK_SIZE:
321-
return await duet.pstarmap_async(
322-
self.run_sweep_async, zip(programs, params_list, repetitions)
323-
)
324-
325-
results = []
326-
for program_chunk, params_chunk, reps_chunk in zip(
327-
_chunked(programs, self.CHUNK_SIZE),
328-
_chunked(params_list, self.CHUNK_SIZE),
329-
_chunked(repetitions, self.CHUNK_SIZE),
330-
):
331-
# Run_sweep_async for the current chunk
332-
await duet.sleep(1) # Delay for 1 second between chunk
333-
results.extend(
334-
await duet.pstarmap_async(
335-
self.run_sweep_async, zip(program_chunk, params_chunk, reps_chunk)
336-
)
337-
)
338-
339-
return results
301+
return await duet.pstarmap_async(
302+
self.run_sweep_async, zip(programs, params_list, repetitions, [limiter] * len(programs))
303+
)
340304

341305
def _normalize_batch_args(
342306
self,
@@ -489,8 +453,3 @@ def _get_measurement_shapes(
489453
)
490454
num_instances[key] += 1
491455
return {k: (num_instances[k], qid_shape) for k, qid_shape in qid_shapes.items()}
492-
493-
494-
def _chunked(iterable: Sequence[T], n: int) -> Iterator[tuple[T, ...]]:
495-
it = iter(iterable)
496-
return iter(lambda: tuple(islice(it, n)), ())

cirq-core/cirq/work/sampler_test.py

Lines changed: 3 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
"""Tests for cirq.Sampler."""
1515
from typing import Sequence
16-
from unittest import mock
1716

1817
import pytest
1918

@@ -224,7 +223,9 @@ async def test_run_batch_async_calls_run_sweep_asynchronously():
224223
params_list = [params1, params2]
225224

226225
class AsyncSampler(cirq.Sampler):
227-
async def run_sweep_async(self, program, params, repetitions: int = 1):
226+
async def run_sweep_async(
227+
self, program, params, repetitions: int = 1, unused: duet.Limiter = duet.Limiter(None)
228+
):
228229
if params == params1:
229230
await duet.sleep(0.001)
230231

@@ -267,55 +268,6 @@ def test_sampler_run_batch_bad_input_lengths():
267268
)
268269

269270

270-
@mock.patch('duet.pstarmap_async')
271-
@pytest.mark.parametrize('call_count', [1, 2, 3])
272-
@duet.sync
273-
async def test_run_batch_async_sends_circuits_in_chunks(spy, call_count):
274-
class AsyncSampler(cirq.Sampler):
275-
CHUNK_SIZE = 3
276-
277-
async def run_sweep_async(self, _, params, __: int = 1):
278-
pass # pragma: no cover
279-
280-
sampler = AsyncSampler()
281-
a = cirq.LineQubit(0)
282-
circuit_list = [cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))] * (
283-
sampler.CHUNK_SIZE * call_count
284-
)
285-
param_list = [cirq.Points('t', [0.3, 0.7])] * (sampler.CHUNK_SIZE * call_count)
286-
287-
await sampler.run_batch_async(circuit_list, params_list=param_list)
288-
289-
assert spy.call_count == call_count
290-
291-
292-
@pytest.mark.parametrize('call_count', [1, 2, 3])
293-
@duet.sync
294-
async def test_run_batch_async_runs_runs_sequentially(call_count):
295-
a = cirq.LineQubit(0)
296-
finished = []
297-
circuit1 = cirq.Circuit(cirq.X(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
298-
circuit2 = cirq.Circuit(cirq.Y(a) ** sympy.Symbol('t'), cirq.measure(a, key='m'))
299-
params1 = cirq.Points('t', [0.3, 0.7])
300-
params2 = cirq.Points('t', [0.4, 0.6])
301-
302-
class AsyncSampler(cirq.Sampler):
303-
CHUNK_SIZE = 1
304-
305-
async def run_sweep_async(self, _, params, __: int = 1):
306-
if params == params1:
307-
await duet.sleep(0.001)
308-
309-
finished.append(params)
310-
311-
sampler = AsyncSampler()
312-
circuit_list = [circuit1, circuit2] * call_count
313-
param_list = [params1, params2] * call_count
314-
await sampler.run_batch_async(circuit_list, params_list=param_list)
315-
316-
assert finished == param_list
317-
318-
319271
def test_sampler_simple_sample_expectation_values():
320272
a = cirq.LineQubit(0)
321273
sampler = cirq.Simulator()

cirq-google/cirq_google/api/v2/program.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
syntax = "proto3";
22

3+
import "tunits/proto/tunits.proto";
4+
35
package cirq.google.api.v2;
46

57
option java_package = "com.google.cirq.google.api.v2";
@@ -296,6 +298,7 @@ message ArgValue {
296298
RepeatedInt64 int64_values = 5;
297299
RepeatedDouble double_values = 6;
298300
RepeatedString string_values = 7;
301+
tunits.Value value_with_unit = 8;
299302
}
300303
}
301304

cirq-google/cirq_google/api/v2/program_pb2.py

Lines changed: 96 additions & 95 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/api/v2/program_pb2.pyi

Lines changed: 8 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

cirq-google/cirq_google/engine/engine_job.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -262,12 +262,14 @@ def delete(self) -> None:
262262
"""Deletes the job and result, if any."""
263263
self.context.client.delete_job(self.project_id, self.program_id, self.job_id)
264264

265-
async def results_async(self) -> Sequence[EngineResult]:
265+
async def results_async(
266+
self, limiter: duet.Limiter = duet.Limiter(None)
267+
) -> Sequence[EngineResult]:
266268
"""Returns the job results, blocking until the job is complete."""
267269
import cirq_google.engine.engine as engine_base
268270

269271
if self._results is None:
270-
result_response = await self._await_result_async()
272+
result_response = await self._await_result_async(limiter)
271273
result = result_response.result
272274
result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
273275
if (
@@ -286,7 +288,9 @@ async def results_async(self) -> Sequence[EngineResult]:
286288
raise ValueError(f'invalid result proto version: {result_type}')
287289
return self._results
288290

289-
async def _await_result_async(self) -> quantum.QuantumResult:
291+
async def _await_result_async(
292+
self, limiter: duet.Limiter = duet.Limiter(None)
293+
) -> quantum.QuantumResult:
290294
if self._job_result_future is not None:
291295
response = await self._job_result_future
292296
if isinstance(response, quantum.QuantumResult):
@@ -299,12 +303,13 @@ async def _await_result_async(self) -> quantum.QuantumResult:
299303
'Internal error: The job response type is not recognized.'
300304
) # pragma: no cover
301305

302-
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
303-
while True:
304-
job = await self._refresh_job_async()
305-
if job.execution_status.state in TERMINAL_STATES:
306-
break
307-
await duet.sleep(1)
306+
async with limiter:
307+
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
308+
while True:
309+
job = await self._refresh_job_async()
310+
if job.execution_status.state in TERMINAL_STATES:
311+
break
312+
await duet.sleep(1)
308313
_raise_on_failure(job)
309314
response = await self.context.client.get_job_results_async(
310315
self.project_id, self.program_id, self.job_id

cirq-google/cirq_google/engine/processor_sampler.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import cirq
1818
import duet
19+
from cirq_google.engine.engine_job import EngineJob
1920

2021
if TYPE_CHECKING:
2122
import cirq_google as cg
@@ -58,9 +59,14 @@ def __init__(
5859
self._run_name = run_name
5960
self._snapshot_id = snapshot_id
6061
self._device_config_name = device_config_name
62+
self._result_limiter = duet.Limiter(None)
6163

6264
async def run_sweep_async(
63-
self, program: 'cirq.AbstractCircuit', params: cirq.Sweepable, repetitions: int = 1
65+
self,
66+
program: 'cirq.AbstractCircuit',
67+
params: cirq.Sweepable,
68+
repetitions: int = 1,
69+
limiter: duet.Limiter = duet.Limiter(None),
6470
) -> Sequence['cg.EngineResult']:
6571
job = await self._processor.run_sweep_async(
6672
program=program,
@@ -70,6 +76,10 @@ async def run_sweep_async(
7076
snapshot_id=self._snapshot_id,
7177
device_config_name=self._device_config_name,
7278
)
79+
80+
if isinstance(job, EngineJob):
81+
return await job.results_async(limiter)
82+
7383
return await job.results_async()
7484

7585
run_sweep = duet.sync(run_sweep_async)
@@ -79,10 +89,12 @@ async def run_batch_async(
7989
programs: Sequence[cirq.AbstractCircuit],
8090
params_list: Optional[Sequence[cirq.Sweepable]] = None,
8191
repetitions: Union[int, Sequence[int]] = 1,
92+
limiter: duet.Limiter = duet.Limiter(10),
8293
) -> Sequence[Sequence['cg.EngineResult']]:
94+
self._result_limiter = limiter
8395
return cast(
8496
Sequence[Sequence['cg.EngineResult']],
85-
await super().run_batch_async(programs, params_list, repetitions),
97+
await super().run_batch_async(programs, params_list, repetitions, self._result_limiter),
8698
)
8799

88100
run_batch = duet.sync(run_batch_async)
@@ -102,3 +114,7 @@ def snapshot_id(self) -> str:
102114
@property
103115
def device_config_name(self) -> str:
104116
return self._device_config_name
117+
118+
@property
119+
def result_limiter(self) -> duet.Limiter:
120+
return self._result_limiter

cirq-google/cirq_google/engine/processor_sampler_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import cirq
2020
import cirq_google as cg
2121
from cirq_google.engine.abstract_processor import AbstractProcessor
22+
from cirq_google.engine.engine_job import EngineJob
2223

2324

2425
@pytest.mark.parametrize('circuit', [cirq.Circuit(), cirq.FrozenCircuit()])
@@ -169,6 +170,31 @@ def test_run_batch_differing_repetitions():
169170
)
170171

171172

173+
def test_run_batch_receives_results_using_limiter():
174+
processor = mock.create_autospec(AbstractProcessor)
175+
run_name = "RUN_NAME"
176+
device_config_name = "DEVICE_CONFIG_NAME"
177+
sampler = cg.ProcessorSampler(
178+
processor=processor, run_name=run_name, device_config_name=device_config_name
179+
)
180+
181+
job = mock.AsyncMock(EngineJob)
182+
183+
processor.run_sweep_async.return_value = job
184+
a = cirq.LineQubit(0)
185+
circuit1 = cirq.Circuit(cirq.X(a))
186+
circuit2 = cirq.Circuit(cirq.Y(a))
187+
params1 = [cirq.ParamResolver({'t': 1})]
188+
params2 = [cirq.ParamResolver({'t': 2})]
189+
circuits = [circuit1, circuit2]
190+
params_list = [params1, params2]
191+
repetitions = [1, 2]
192+
193+
sampler.run_batch(circuits, params_list, repetitions)
194+
195+
job.results_async.assert_called_with(sampler.result_limiter)
196+
197+
172198
def test_processor_sampler_processor_property():
173199
processor = mock.create_autospec(AbstractProcessor)
174200
sampler = cg.ProcessorSampler(processor=processor)

cirq-google/cirq_google/serialization/arg_func_langs.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from cirq_google.api import v2
2121
from cirq_google.ops import InternalGate
2222
from cirq.qis import CliffordTableau
23+
import tunits
2324

2425
SUPPORTED_FUNCTIONS_FOR_LANGUAGE: Dict[Optional[str], FrozenSet[str]] = {
2526
'': frozenset(),
@@ -33,8 +34,10 @@
3334
SUPPORTED_SYMPY_OPS = (sympy.Symbol, sympy.Add, sympy.Mul, sympy.Pow)
3435

3536
# Argument types for gates.
36-
ARG_LIKE = Union[int, float, numbers.Real, Sequence[bool], str, sympy.Expr]
37-
ARG_RETURN_LIKE = Union[float, int, str, List[bool], List[int], List[float], List[str], sympy.Expr]
37+
ARG_LIKE = Union[int, float, numbers.Real, Sequence[bool], str, sympy.Expr, tunits.Value]
38+
ARG_RETURN_LIKE = Union[
39+
float, int, str, List[bool], List[int], List[float], List[str], sympy.Expr, tunits.Value
40+
]
3841
FLOAT_ARG_LIKE = Union[float, sympy.Expr]
3942

4043
# Types for comparing floats
@@ -182,6 +185,8 @@ def arg_to_proto(
182185
)
183186
field, types_tuple = numerical_fields[cur_index]
184187
field.extend(types_tuple[0](x) for x in value)
188+
elif isinstance(value, tunits.Value):
189+
msg.arg_value.value_with_unit.MergeFrom(value.to_proto())
185190
else:
186191
_arg_func_to_proto(value, arg_function_language, msg)
187192

@@ -329,6 +334,8 @@ def arg_from_proto(
329334
return [float(v) for v in arg_value.double_values.values]
330335
if which_val == 'string_values':
331336
return [str(v) for v in arg_value.string_values.values]
337+
if which_val == 'value_with_unit':
338+
return tunits.Value.from_proto(arg_value.value_with_unit)
332339
raise ValueError(f'Unrecognized value type: {which_val!r}')
333340

334341
if which == 'symbol':

0 commit comments

Comments
 (0)