Skip to content

Commit ea08c0f

Browse files
Create a visualization class and wire it through the code (#83)
Currently visualization is created by parsing the logs ... this PR introduces a different way to do this by creating a visualization class ... for now this class can only write to file to be inline with how things are done now but can be improved in different ways ... for example it can record data that the python layer can use directly with matplotlib to use the class set the flag `create_visualization=True` in the TesseractConfig and after running the decoding call `decoder.visualizer.write(file_path)` to write to the desired file
1 parent 0a7aa24 commit ea08c0f

File tree

11 files changed

+145
-10
lines changed

11 files changed

+145
-10
lines changed

src/BUILD

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ pybind_library(
7070
"common.pybind.h",
7171
"utils.pybind.h",
7272
"simplex.pybind.h",
73+
"visualization.pybind.h",
7374
"tesseract.pybind.h",
7475
],
7576
deps = [
@@ -113,6 +114,19 @@ cc_library(
113114
],
114115
)
115116

117+
cc_library(
118+
name = "libviz",
119+
srcs = ["visualization.cc"],
120+
hdrs = ["visualization.h"],
121+
copts = OPT_COPTS,
122+
linkopts = OPT_LINKOPTS,
123+
deps = [
124+
":libutils",
125+
"@boost//:dynamic_bitset",
126+
],
127+
128+
)
129+
116130
cc_library(
117131
name = "libtesseract",
118132
srcs = ["tesseract.cc"],
@@ -121,6 +135,7 @@ cc_library(
121135
linkopts = OPT_LINKOPTS,
122136
deps = [
123137
":libutils",
138+
":libviz",
124139
"@boost//:dynamic_bitset",
125140
],
126141
)

src/common.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include "common.h"
1616

17-
std::string common::Symptom::str() {
17+
std::string common::Symptom::str() const {
1818
std::string s = "Symptom{";
1919
for (size_t d : detectors) {
2020
s += "D" + std::to_string(d);
@@ -63,7 +63,7 @@ common::Error::Error(const stim::DemInstruction& error) {
6363
symptom.observables = observables;
6464
}
6565

66-
std::string common::Error::str() {
66+
std::string common::Error::str() const {
6767
return "Error{cost=" + std::to_string(likelihood_cost) + ", symptom=" + symptom.str() + "}";
6868
}
6969

src/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ struct Symptom {
4242
bool operator==(const Symptom& other) const {
4343
return detectors == other.detectors && observables == other.observables;
4444
}
45-
std::string str();
45+
std::string str() const;
4646
};
4747

4848
// Represents an error / weighted hyperedge
@@ -53,7 +53,7 @@ struct Error {
5353
Error(double likelihood_cost, std::vector<int>& detectors, std::vector<int> observables)
5454
: likelihood_cost(likelihood_cost), symptom{detectors, observables} {}
5555
Error(const stim::DemInstruction& error);
56-
std::string str();
56+
std::string str() const;
5757

5858
// Get/calculate the probability from the likelihood cost.
5959
double get_probability() const;

src/py/tesseract_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
def test_create_config():
4444
assert (
4545
str(tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL))
46-
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0)"
46+
== "TesseractConfig(dem=DetectorErrorModel_Object, det_beam=65535, no_revisit_dets=0, at_most_two_errors_per_detector=0, verbose=0, pqlimit=18446744073709551615, det_orders=[], det_penalty=0, create_visualization=0)"
4747
)
4848
assert (
4949
tesseract_decoder.tesseract.TesseractConfig(_DETECTOR_ERROR_MODEL).dem

src/tesseract.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ std::string TesseractConfig::str() {
6161
ss << "verbose=" << config.verbose << ", ";
6262
ss << "pqlimit=" << config.pqlimit << ", ";
6363
ss << "det_orders=" << config.det_orders << ", ";
64-
ss << "det_penalty=" << config.det_penalty << ")";
64+
ss << "det_penalty=" << config.det_penalty << ", ";
65+
ss << "create_visualization=" << config.create_visualization;
66+
ss << ")";
6567
return ss.str();
6668
}
6769

@@ -124,6 +126,11 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
124126
num_errors = config.dem.count_errors();
125127
num_observables = config.dem.count_observables();
126128
initialize_structures(config.dem.count_detectors());
129+
if (config.create_visualization) {
130+
auto detectors = get_detector_coords(config.dem);
131+
visualizer.add_detector_coords(detectors);
132+
visualizer.add_errors(errors);
133+
}
127134
}
128135

129136
void TesseractDecoder::initialize_structures(size_t num_detectors) {
@@ -294,6 +301,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
294301
flip_detectors_and_block_errors(detector_order, node.errors, detectors, detector_cost_tuples);
295302

296303
if (node.num_detectors == 0) {
304+
if (config.create_visualization) {
305+
visualizer.add_activated_errors(node.errors);
306+
visualizer.add_activated_detectors(detectors, num_detectors);
307+
}
297308
if (config.verbose) {
298309
std::cout << "activated_errors = ";
299310
for (size_t oei : node.errors) {
@@ -318,6 +329,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
318329
if (config.no_revisit_dets && !visited_detectors[node.num_detectors].insert(detectors).second)
319330
continue;
320331

332+
if (config.create_visualization) {
333+
visualizer.add_activated_errors(node.errors);
334+
visualizer.add_activated_detectors(detectors, num_detectors);
335+
}
321336
if (config.verbose) {
322337
std::cout.precision(13);
323338
std::cout << "len(pq) = " << pq.size() << " num_pq_pushed = " << num_pq_pushed << std::endl;

src/tesseract.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "common.h"
2626
#include "stim.h"
2727
#include "utils.h"
28+
#include "visualization.h"
2829

2930
constexpr size_t INF_DET_BEAM = std::numeric_limits<uint16_t>::max();
3031

@@ -38,6 +39,7 @@ struct TesseractConfig {
3839
size_t pqlimit = std::numeric_limits<size_t>::max();
3940
std::vector<std::vector<size_t>> det_orders;
4041
double det_penalty = 0;
42+
bool create_visualization = false;
4143

4244
std::string str();
4345
};
@@ -64,6 +66,8 @@ struct ErrorCost {
6466

6567
struct TesseractDecoder {
6668
TesseractConfig config;
69+
Visualizer visualizer;
70+
6771
explicit TesseractDecoder(TesseractConfig config);
6872

6973
// Clears the predicted_errors_buffer and fills it with the decoded errors for

src/tesseract.pybind.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "pybind11/detail/common.h"
2222
#include "simplex.pybind.h"
2323
#include "utils.pybind.h"
24+
#include "visualization.pybind.h"
2425

2526
PYBIND11_MODULE(tesseract_decoder, tesseract) {
2627
py::module::import("stim");
@@ -29,6 +30,7 @@ PYBIND11_MODULE(tesseract_decoder, tesseract) {
2930
add_utils_module(tesseract);
3031
add_simplex_module(tesseract);
3132
add_tesseract_module(tesseract);
33+
add_visualization_module(tesseract);
3234

3335
// Adds a context manager to the python library that can be used to redirect C++'s stdout/stderr
3436
// to python's stdout/stderr at run time like

src/tesseract.pybind.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ TesseractConfig tesseract_config_maker(
3737
bool no_revisit_dets = false, bool at_most_two_errors_per_detector = false,
3838
bool verbose = false, size_t pqlimit = std::numeric_limits<size_t>::max(),
3939
std::vector<std::vector<size_t>> det_orders = std::vector<std::vector<size_t>>(),
40-
double det_penalty = 0.0) {
40+
double det_penalty = 0.0, bool create_visualization = false) {
4141
stim::DetectorErrorModel input_dem = parse_py_object<stim::DetectorErrorModel>(dem);
4242
return TesseractConfig({input_dem, det_beam, beam_climbing, no_revisit_dets,
4343
at_most_two_errors_per_detector, verbose, pqlimit, det_orders,
44-
det_penalty});
44+
det_penalty, create_visualization});
4545
}
4646
}; // namespace
4747
void add_tesseract_module(py::module& root) {
@@ -61,6 +61,7 @@ void add_tesseract_module(py::module& root) {
6161
py::arg("at_most_two_errors_per_detector") = false, py::arg("verbose") = false,
6262
py::arg("pqlimit") = std::numeric_limits<size_t>::max(),
6363
py::arg("det_orders") = std::vector<std::vector<size_t>>(), py::arg("det_penalty") = 0.0,
64+
py::arg("create_visualization") = false,
6465
R"pbdoc(
6566
The constructor for the `TesseractConfig` class.
6667
@@ -86,6 +87,8 @@ void add_tesseract_module(py::module& root) {
8687
will generate its own orderings.
8788
det_penalty : float, default=0.0
8889
A penalty value added to the cost of each detector visited.
90+
create_visualization: bool, defualt=False
91+
Whether to record the information needed to create a visualization or not.
8992
)pbdoc")
9093
.def_property("dem", &dem_getter<TesseractConfig>, &dem_setter<TesseractConfig>,
9194
"The `stim.DetectorErrorModel` that defines the error channels and detectors.")
@@ -106,6 +109,8 @@ void add_tesseract_module(py::module& root) {
106109
"A list of pre-specified detector orderings.")
107110
.def_readwrite("det_penalty", &TesseractConfig::det_penalty,
108111
"The penalty cost added for each detector.")
112+
.def_readwrite("create_visualization", &TesseractConfig::create_visualization,
113+
"If True, records necessary information to create visualization.")
109114
.def("__str__", &TesseractConfig::str)
110115
.def("compile_decoder", &_compile_tesseract_decoder_helper,
111116
py::return_value_policy::take_ownership,
@@ -374,7 +379,10 @@ void add_tesseract_module(py::module& root) {
374379
.def_readwrite("errors", &TesseractDecoder::errors,
375380
"The list of all errors in the detector error model.")
376381
.def_readwrite("num_observables", &TesseractDecoder::num_observables,
377-
"The total number of logical observables in the detector error model.");
382+
"The total number of logical observables in the detector error model.")
383+
.def_readonly("visualizer", &TesseractDecoder::visualizer,
384+
"An object that can (if config.create_visualization=True) be used to generate "
385+
"visualization of the algorithm");
378386
}
379387

380-
#endif
388+
#endif

src/visualization.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
#include "visualization.h"
3+
4+
void Visualizer::add_errors(const std::vector<common::Error>& errors) {
5+
for (auto& error : errors) {
6+
lines.push_back(error.str());
7+
}
8+
}
9+
void Visualizer::add_detector_coords(const std::vector<std::vector<double>>& detector_coords) {
10+
for (size_t d = 0; d < detector_coords.size(); ++d) {
11+
std::stringstream ss;
12+
ss << "Detector D" << d << " coordinate (";
13+
size_t e = std::min(3ul, detector_coords[d].size());
14+
for (size_t i = 0; i < e; ++i) {
15+
ss << detector_coords[d][i];
16+
if (i + 1 < e) ss << ", ";
17+
}
18+
ss << ")";
19+
lines.push_back(ss.str());
20+
}
21+
}
22+
23+
void Visualizer::add_activated_errors(const std::vector<size_t>& activated_errors) {
24+
std::stringstream ss;
25+
ss << "activated_errors = ";
26+
for (size_t oei : activated_errors) {
27+
ss << oei << ", ";
28+
}
29+
lines.push_back(ss.str());
30+
}
31+
32+
void Visualizer::add_activated_detectors(const boost::dynamic_bitset<>& detectors,
33+
size_t num_detectors) {
34+
std::stringstream ss;
35+
ss << "activated_detectors = ";
36+
for (size_t d = 0; d < num_detectors; ++d) {
37+
if (detectors[d]) {
38+
ss << d << ", ";
39+
}
40+
}
41+
lines.push_back(ss.str());
42+
}
43+
44+
void Visualizer::write(const char* fpath) {
45+
FILE* fout = fopen(fpath, "w");
46+
47+
for (std::string& line : lines) {
48+
fprintf(fout, line.c_str());
49+
fputs("\n", fout);
50+
}
51+
52+
fclose(fout);
53+
}

src/visualization.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef _VISUALIZATION_H
2+
#define _VISUALIZATION_H
3+
4+
#include <boost/dynamic_bitset.hpp>
5+
#include <list>
6+
#include <vector>
7+
8+
#include "common.h"
9+
10+
struct Visualizer {
11+
void add_detector_coords(const std::vector<std::vector<double>> &);
12+
void add_errors(const std::vector<common::Error> &);
13+
void add_activated_errors(const std::vector<size_t> &);
14+
void add_activated_detectors(const boost::dynamic_bitset<> &, size_t);
15+
16+
void write(const char *fpath);
17+
18+
private:
19+
std::list<std::string> lines;
20+
};
21+
22+
#endif

0 commit comments

Comments
 (0)