Skip to content

Commit f78ce8a

Browse files
authored
Expose det_order building in the python module (#65)
- The flag `--det-order-bfs` significantly improves the performance of tesseract for decoding e.g. color code circuit-level noise - There is no way to use it from python currently because the logic is contained in `tesseract_main.cc` - Here, we move it into utils and expose it in the python module
1 parent 547adc3 commit f78ce8a

File tree

5 files changed

+116
-89
lines changed

5 files changed

+116
-89
lines changed

src/py/utils_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def test_build_detector_graph():
4444
]
4545

4646

47+
def test_build_det_orders():
48+
assert tesseract_decoder.utils.build_det_orders(
49+
_DETECTOR_ERROR_MODEL, num_det_orders=1, seed=0
50+
) == [[0, 1]]
51+
52+
53+
def test_build_det_orders_no_bfs():
54+
assert tesseract_decoder.utils.build_det_orders(
55+
_DETECTOR_ERROR_MODEL, num_det_orders=1, det_order_bfs=False, seed=0
56+
) == [[0, 1]]
57+
58+
4759
def test_get_errors_from_dem():
4860
expected = "Error{cost=1.945910, symptom=Symptom{D0 }}, Error{cost=0.510826, symptom=Symptom{D0 D1 }}, Error{cost=1.098612, symptom=Symptom{D1 }}"
4961
assert (

src/tesseract_main.cc

Lines changed: 3 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,8 @@ struct Args {
171171

172172
// Sample orientations of the error model to use for the det priority
173173
{
174-
config.det_orders.resize(num_det_orders);
175-
std::mt19937_64 rng(det_order_seed);
176-
std::normal_distribution<double> dist(/*mean=*/0, /*stddev=*/1);
177-
178-
std::vector<std::vector<double>> detector_coords = get_detector_coords(config.dem);
179174
if (verbose) {
175+
auto detector_coords = get_detector_coords(config.dem);
180176
for (size_t d = 0; d < detector_coords.size(); ++d) {
181177
std::cout << "Detector D" << d << " coordinate (";
182178
size_t e = std::min(3ul, detector_coords[d].size());
@@ -187,88 +183,8 @@ struct Args {
187183
std::cout << ")" << std::endl;
188184
}
189185
}
190-
191-
if (det_order_bfs) {
192-
auto graph = build_detector_graph(config.dem);
193-
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
194-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
195-
std::vector<size_t> perm;
196-
perm.reserve(graph.size());
197-
std::vector<char> visited(graph.size(), false);
198-
std::queue<size_t> q;
199-
size_t start = dist_det(rng);
200-
while (perm.size() < graph.size()) {
201-
if (!visited[start]) {
202-
visited[start] = true;
203-
q.push(start);
204-
perm.push_back(start);
205-
}
206-
while (!q.empty()) {
207-
size_t cur = q.front();
208-
q.pop();
209-
auto neigh = graph[cur];
210-
std::shuffle(neigh.begin(), neigh.end(), rng);
211-
for (size_t n : neigh) {
212-
if (!visited[n]) {
213-
visited[n] = true;
214-
q.push(n);
215-
perm.push_back(n);
216-
}
217-
}
218-
}
219-
if (perm.size() < graph.size()) {
220-
do {
221-
start = dist_det(rng);
222-
} while (visited[start]);
223-
}
224-
}
225-
std::vector<size_t> inv_perm(graph.size());
226-
for (size_t i = 0; i < perm.size(); ++i) {
227-
inv_perm[perm[i]] = i;
228-
}
229-
config.det_orders[det_order] = inv_perm;
230-
}
231-
} else {
232-
std::vector<double> inner_products(config.dem.count_detectors());
233-
234-
if (!detector_coords.size() || !detector_coords.at(0).size()) {
235-
// If there are no detector coordinates, just use the standard
236-
// ordering of the indices.
237-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
238-
config.det_orders[det_order].resize(config.dem.count_detectors());
239-
std::iota(config.det_orders[det_order].begin(), config.det_orders[det_order].end(), 0);
240-
}
241-
242-
} else {
243-
// Use the coordinates to order the detectors based on a random
244-
// orientation
245-
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
246-
// Sample a direction
247-
std::vector<double> orientation_vector;
248-
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
249-
orientation_vector.push_back(dist(rng));
250-
}
251-
252-
for (size_t i = 0; i < detector_coords.size(); ++i) {
253-
inner_products[i] = 0;
254-
for (size_t j = 0; j < orientation_vector.size(); ++j) {
255-
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
256-
}
257-
}
258-
std::vector<size_t> perm(config.dem.count_detectors());
259-
std::iota(perm.begin(), perm.end(), 0);
260-
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
261-
return inner_products[i] > inner_products[j];
262-
});
263-
// Invert the permutation
264-
std::vector<size_t> inv_perm(config.dem.count_detectors());
265-
for (size_t i = 0; i < perm.size(); ++i) {
266-
inv_perm[perm[i]] = i;
267-
}
268-
config.det_orders[det_order] = inv_perm;
269-
}
270-
}
271-
}
186+
config.det_orders =
187+
build_det_orders(config.dem, num_det_orders, det_order_bfs, det_order_seed);
272188
}
273189

274190
if (sample_num_shots > 0) {

src/utils.cc

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
#include <filesystem>
2020
#include <fstream>
2121
#include <iostream>
22+
#include <numeric>
23+
#include <queue>
2224
#include <random>
2325
#include <string>
2426

2527
#include "common.h"
2628
#include "stim.h"
2729

28-
std::vector<std::vector<double>> get_detector_coords(stim::DetectorErrorModel& dem) {
30+
std::vector<std::vector<double>> get_detector_coords(const stim::DetectorErrorModel& dem) {
2931
std::vector<std::vector<double>> detector_coords;
3032
for (const stim::DemInstruction& instruction : dem.flattened().instructions) {
3133
switch (instruction.type) {
@@ -79,6 +81,91 @@ std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorM
7981
return neighbors;
8082
}
8183

84+
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
85+
size_t num_det_orders, bool det_order_bfs,
86+
uint64_t seed) {
87+
std::vector<std::vector<size_t>> det_orders(num_det_orders);
88+
std::mt19937_64 rng(seed);
89+
std::normal_distribution<double> dist(0, 1);
90+
91+
auto detector_coords = get_detector_coords(dem);
92+
93+
if (det_order_bfs) {
94+
auto graph = build_detector_graph(dem);
95+
std::uniform_int_distribution<size_t> dist_det(0, graph.size() - 1);
96+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
97+
std::vector<size_t> perm;
98+
perm.reserve(graph.size());
99+
std::vector<char> visited(graph.size(), false);
100+
std::queue<size_t> q;
101+
size_t start = dist_det(rng);
102+
while (perm.size() < graph.size()) {
103+
if (!visited[start]) {
104+
visited[start] = true;
105+
q.push(start);
106+
perm.push_back(start);
107+
}
108+
while (!q.empty()) {
109+
size_t cur = q.front();
110+
q.pop();
111+
auto neigh = graph[cur];
112+
std::shuffle(neigh.begin(), neigh.end(), rng);
113+
for (size_t n : neigh) {
114+
if (!visited[n]) {
115+
visited[n] = true;
116+
q.push(n);
117+
perm.push_back(n);
118+
}
119+
}
120+
}
121+
if (perm.size() < graph.size()) {
122+
do {
123+
start = dist_det(rng);
124+
} while (visited[start]);
125+
}
126+
}
127+
std::vector<size_t> inv_perm(graph.size());
128+
for (size_t i = 0; i < perm.size(); ++i) {
129+
inv_perm[perm[i]] = i;
130+
}
131+
det_orders[det_order] = inv_perm;
132+
}
133+
} else {
134+
std::vector<double> inner_products(dem.count_detectors());
135+
if (!detector_coords.size() || !detector_coords.at(0).size()) {
136+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
137+
det_orders[det_order].resize(dem.count_detectors());
138+
std::iota(det_orders[det_order].begin(), det_orders[det_order].end(), 0);
139+
}
140+
} else {
141+
for (size_t det_order = 0; det_order < num_det_orders; ++det_order) {
142+
std::vector<double> orientation_vector;
143+
for (size_t i = 0; i < detector_coords.at(0).size(); ++i) {
144+
orientation_vector.push_back(dist(rng));
145+
}
146+
147+
for (size_t i = 0; i < detector_coords.size(); ++i) {
148+
inner_products[i] = 0;
149+
for (size_t j = 0; j < orientation_vector.size(); ++j) {
150+
inner_products[i] += detector_coords[i][j] * orientation_vector[j];
151+
}
152+
}
153+
std::vector<size_t> perm(dem.count_detectors());
154+
std::iota(perm.begin(), perm.end(), 0);
155+
std::sort(perm.begin(), perm.end(), [&](const size_t& i, const size_t& j) {
156+
return inner_products[i] > inner_products[j];
157+
});
158+
std::vector<size_t> inv_perm(dem.count_detectors());
159+
for (size_t i = 0; i < perm.size(); ++i) {
160+
inv_perm[perm[i]] = i;
161+
}
162+
det_orders[det_order] = inv_perm;
163+
}
164+
}
165+
}
166+
return det_orders;
167+
}
168+
82169
bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem,
83170
std::vector<stim::SparseShot>& shots) {
84171
stim::DemSampler<stim::MAX_BITWORD_WIDTH> sampler(dem, std::mt19937_64{seed}, num_shots);

src/utils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,16 @@
2828

2929
constexpr const double EPSILON = 1e-7;
3030

31-
std::vector<std::vector<double>> get_detector_coords(stim::DetectorErrorModel& dem);
31+
std::vector<std::vector<double>> get_detector_coords(const stim::DetectorErrorModel& dem);
3232

3333
// Builds an adjacency list graph where two detectors share an edge iff an error
3434
// in the model activates them both.
3535
std::vector<std::vector<size_t>> build_detector_graph(const stim::DetectorErrorModel& dem);
3636

37+
std::vector<std::vector<size_t>> build_det_orders(const stim::DetectorErrorModel& dem,
38+
size_t num_det_orders, bool det_order_bfs = true,
39+
uint64_t seed = 0);
40+
3741
const double INF = std::numeric_limits<double>::infinity();
3842

3943
bool sampling_from_dem(uint64_t seed, size_t num_shots, stim::DetectorErrorModel dem,

src/utils.pybind.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ void add_utils_module(py::module &root) {
4242
return build_detector_graph(input_dem);
4343
},
4444
py::arg("dem"));
45+
m.def(
46+
"build_det_orders",
47+
[](py::object dem, size_t num_det_orders, bool det_order_bfs, uint64_t seed) {
48+
auto input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
49+
return build_det_orders(input_dem, num_det_orders, det_order_bfs, seed);
50+
},
51+
py::arg("dem"), py::arg("num_det_orders"), py::arg("det_order_bfs") = true,
52+
py::arg("seed") = 0);
4553
m.def(
4654
"get_errors_from_dem",
4755
[](py::object dem) {

0 commit comments

Comments
 (0)