@@ -24,12 +24,11 @@ bool Node::operator>(const Node& other) const {
2424
2525double TesseractDecoder::get_detcost (size_t d,
2626 const std::vector<char >& blocked_errs,
27- const std::vector<size_t >& det_counts,
28- const std::vector<char >& dets) const {
27+ const std::vector<size_t >& det_counts) const {
2928 double min_cost = INF;
3029 for (size_t ei : d2e[d]) {
3130 if (!blocked_errs[ei]) {
32- double ecost = ( errors[ei].likelihood_cost ) / det_counts[ei];
31+ double ecost = errors[ei].likelihood_cost / det_counts[ei];
3332 min_cost = std::min (min_cost, ecost);
3433 assert (det_counts[ei]);
3534 }
@@ -46,7 +45,7 @@ TesseractDecoder::TesseractDecoder(TesseractConfig config_) : config(config_) {
4645 assert (config.det_orders [i].size () == config.dem .count_detectors ());
4746 }
4847 }
49- assert (this -> config .det_orders .size ());
48+ assert (config.det_orders .size ());
5049 errors = get_errors_from_dem (config.dem .flattened ());
5150 if (config.verbose ) {
5251 for (auto & error : errors) {
@@ -120,7 +119,7 @@ void TesseractDecoder::decode_to_errors(
120119 size_t det_order = beam % config.det_orders .size ();
121120 decode_to_errors (detections, det_order);
122121 double this_cost = cost_from_errors (predicted_errors_buffer);
123- if (!low_confidence_flag and this_cost < best_cost) {
122+ if (!low_confidence_flag && this_cost < best_cost) {
124123 best_errors = predicted_errors_buffer;
125124 best_cost = this_cost;
126125 }
@@ -137,7 +136,7 @@ void TesseractDecoder::decode_to_errors(
137136 ++det_order) {
138137 decode_to_errors (detections, det_order);
139138 double this_cost = cost_from_errors (predicted_errors_buffer);
140- if (!low_confidence_flag and this_cost < best_cost) {
139+ if (!low_confidence_flag && this_cost < best_cost) {
141140 best_errors = predicted_errors_buffer;
142141 best_cost = this_cost;
143142 }
@@ -153,7 +152,7 @@ void TesseractDecoder::decode_to_errors(
153152 }
154153 config.det_beam = max_det_beam;
155154 predicted_errors_buffer = best_errors;
156- low_confidence_flag = ( best_cost == std::numeric_limits<double >::max () );
155+ low_confidence_flag = best_cost == std::numeric_limits<double >::max ();
157156}
158157
159158bool QNode::operator >(const QNode& other) const {
@@ -183,20 +182,16 @@ void TesseractDecoder::to_node(const QNode& qnode,
183182 // Reconstruct the blocked_errs
184183 for (size_t oei : d2e[min_det]) {
185184 node.blocked_errs [oei] = true ;
186- if (!config.at_most_two_errors_per_detector and oei == ei) break ;
185+ if (!config.at_most_two_errors_per_detector && oei == ei) break ;
187186 }
188187
189188 // Reconstruct the dets
190189 for (size_t d : edets[ei]) {
191- if (node.dets [d]) {
192- node.dets [d] = false ;
193- if (config.at_most_two_errors_per_detector ) {
194- for (size_t oei : d2e[d]) {
195- node.blocked_errs [oei] = true ;
196- }
190+ node.dets [d] = !node.dets [d];
191+ if (!node.dets [d] && config.at_most_two_errors_per_detector ) {
192+ for (size_t oei : d2e[d]) {
193+ node.blocked_errs [oei] = true ;
197194 }
198- } else {
199- node.dets [d] = true ;
200195 }
201196 }
202197 }
@@ -217,40 +212,37 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
217212 std::unordered_set<std::vector<char >, VectorCharHash>>
218213 discovered_dets;
219214
220- size_t min_num_dets;
221- {
222- std::vector<size_t > errs;
223- std::vector<char > blocked_errs (num_errors, false );
224- std::vector<size_t > det_counts (num_errors, 0 );
215+ size_t min_num_dets = detections.size ();
216+ std::vector<size_t > errs;
217+ std::vector<char > blocked_errs (num_errors, false );
218+ std::vector<size_t > det_counts (num_errors, 0 );
225219
226- for (size_t d = 0 ; d < num_detectors; ++d) {
227- if (!dets[d]) continue ;
228- for (int ei : d2e[d]) {
229- det_counts[ei]++;
230- }
231- }
232- double initial_cost = 0.0 ;
233- for (size_t d = 0 ; d < num_detectors; ++d) {
234- if (!dets[d]) continue ;
235- initial_cost += get_detcost (d, blocked_errs, det_counts, dets);
220+ for (size_t d = 0 ; d < num_detectors; ++d) {
221+ if (!dets[d]) continue ;
222+ for (int ei : d2e[d]) {
223+ ++det_counts[ei];
236224 }
237- if (initial_cost == INF) {
238- low_confidence_flag = true ;
239- return ;
240- }
241- min_num_dets =
242- static_cast <size_t >(std::count (dets.begin (), dets.end (), true ));
243- // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
244- pq.push ({initial_cost, min_num_dets, errs});
245225 }
246- size_t num_pq_pushed = 1 ;
226+ double initial_cost = 0.0 ;
227+ for (size_t d = 0 ; d < num_detectors; ++d) {
228+ if (!dets[d]) continue ;
229+ initial_cost += get_detcost (d, blocked_errs, det_counts);
230+ }
231+ if (initial_cost == INF) {
232+ low_confidence_flag = true ;
233+ return ;
234+ }
235+ // pq.push({errs, dets, initial_cost, min_num_dets, blocked_errs});
236+ pq.push ({initial_cost, min_num_dets, errs});
247237
238+ size_t num_pq_pushed = 1 ;
248239 size_t max_num_dets = min_num_dets + det_beam;
249240 Node node;
250241 std::vector<size_t > next_det_counts;
251242 std::vector<char > next_next_blocked_errs;
252243 std::vector<char > next_dets;
253244 std::vector<size_t > next_errs;
245+
254246 while (!pq.empty ()) {
255247 const QNode qnode = pq.top ();
256248 if (qnode.num_dets > max_num_dets) {
@@ -280,13 +272,12 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
280272 }
281273 // Store the predicted errors into the buffer
282274 predicted_errors_buffer = node.errs ;
283-
284275 return ;
285276 }
286277
287278 if (node.num_dets > max_num_dets) continue ;
288279
289- if (config.no_revisit_dets and
280+ if (config.no_revisit_dets &&
290281 !discovered_dets[node.num_dets ].insert (node.dets ).second ) {
291282 continue ;
292283 }
@@ -336,9 +327,10 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
336327 for (size_t d = 0 ; d < num_detectors; ++d) {
337328 if (!node.dets [d]) continue ;
338329 for (int ei : d2e[d]) {
339- det_counts[ei]++ ;
330+ ++ det_counts[ei];
340331 }
341332 }
333+
342334 // We cache as we recompute the det costs
343335 std::vector<double > det_costs (num_detectors, -1 );
344336 std::vector<char > next_blocked_errs = node.blocked_errs ;
@@ -362,19 +354,14 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
362354 // iteration
363355 if (last_ei != std::numeric_limits<size_t >::max ()) {
364356 for (int d : edets[last_ei]) {
365- if (node.dets [d]) {
366- for (int oei : d2e[d]) {
367- ++next_det_counts[oei];
368- }
369- } else {
370- for (int oei : d2e[d]) {
371- --next_det_counts[oei];
372- }
357+ int fired = node.dets [d] ? 1 : -1 ;
358+ for (int oei : d2e[d]) {
359+ next_det_counts[oei] += fired;
373360 }
374361 }
375362 }
376- last_ei = ei;
377363
364+ last_ei = ei;
378365 next_blocked_errs[ei] = true ;
379366
380367 next_errs = node.errs ;
@@ -384,24 +371,21 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
384371 double next_cost = node.cost + errors[ei].likelihood_cost ;
385372
386373 size_t next_num_dets = node.num_dets ;
387- next_next_blocked_errs = next_blocked_errs;
374+ if (config.at_most_two_errors_per_detector ) {
375+ next_next_blocked_errs = next_blocked_errs;
376+ }
377+
388378 for (int d : edets[ei]) {
389- if (next_dets[d]) {
390- next_dets[d] = false ;
391- --next_num_dets;
392- for (int oei : d2e[d]) {
393- --next_det_counts[oei];
394- }
395- if (config.at_most_two_errors_per_detector ) {
396- for (size_t oei : d2e[d]) {
397- next_next_blocked_errs[oei] = true ;
398- }
399- }
400- } else {
401- next_dets[d] = true ;
402- ++next_num_dets;
403- for (int oei : d2e[d]) {
404- ++next_det_counts[oei];
379+ next_dets[d] = !next_dets[d];
380+ int fired = next_dets[d] ? 1 : -1 ;
381+ next_num_dets += fired;
382+ for (int oei : d2e[d]) {
383+ next_det_counts[oei] += fired;
384+ }
385+
386+ if (!next_dets[d] && config.at_most_two_errors_per_detector ) {
387+ for (size_t oei : d2e[d]) {
388+ next_next_blocked_errs[oei] = true ;
405389 }
406390 }
407391 }
@@ -410,7 +394,7 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
410394 continue ;
411395 }
412396
413- if (config.no_revisit_dets and
397+ if (config.no_revisit_dets &&
414398 discovered_dets[next_num_dets].find (next_dets) !=
415399 discovered_dets[next_num_dets].end ()) {
416400 continue ;
@@ -420,23 +404,22 @@ void TesseractDecoder::decode_to_errors(const std::vector<uint64_t>& detections,
420404 if (node.dets [d]) {
421405 if (det_costs[d] == -1 ) {
422406 det_costs[d] =
423- get_detcost (d, node.blocked_errs , det_counts, node. dets );
407+ get_detcost (d, node.blocked_errs , det_counts);
424408 }
425409 next_cost -= det_costs[d];
426410 } else {
427- next_cost += get_detcost (d, next_next_blocked_errs, next_det_counts,
428- next_dets);
411+ next_cost += get_detcost (d, config.at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts);
429412 }
430413 }
431414 for (size_t od : eneighbors[ei]) {
432415 if (!node.dets [od] || !next_dets[od]) continue ;
433416 if (det_costs[od] == -1 ) {
434417 det_costs[od] =
435- get_detcost (od, node.blocked_errs , det_counts, node. dets );
418+ get_detcost (od, node.blocked_errs , det_counts);
436419 }
437420 next_cost -= det_costs[od];
438421 next_cost +=
439- get_detcost (od, next_next_blocked_errs, next_det_counts, next_dets );
422+ get_detcost (od, config. at_most_two_errors_per_detector ? next_next_blocked_errs : next_blocked_errs, next_det_counts );
440423 }
441424
442425 if (next_cost == INF) {
@@ -496,4 +479,4 @@ void TesseractDecoder::decode_shots(
496479 for (size_t i = 0 ; i < shots.size (); ++i) {
497480 obs_predicted[i] = decode (shots[i].hits );
498481 }
499- }
482+ }
0 commit comments