diff --git a/src/api/hocrrenderer.cpp b/src/api/hocrrenderer.cpp index f3ea4f5452..886b6bf2ab 100644 --- a/src/api/hocrrenderer.cpp +++ b/src/api/hocrrenderer.cpp @@ -133,7 +133,7 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) { if (tesseract_ == nullptr || (page_res_ == nullptr && Recognize(monitor) < 0)) return nullptr; - int lcnt = 1, bcnt = 1, pcnt = 1, wcnt = 1, scnt = 1, tcnt = 1, gcnt = 1; + int lcnt = 1, bcnt = 1, pcnt = 1, wcnt = 1, scnt = 1, tcnt = 1, ccnt = 1; int page_id = page_number + 1; // hOCR uses 1-based page numbers. bool para_is_ltr = true; // Default direction is LTR const char* paragraph_lang = nullptr; @@ -230,20 +230,14 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) { // Now, process the word... int32_t lstm_choice_mode = tesseract_->lstm_choice_mode; - std::vector>>* rawTimestepMap = - nullptr; - std::vector>>* choiceMap = + std::vector>>>* rawTimestepMap = nullptr; std::vector>>* CTCMap = nullptr; - std::vector>>>* - symbolMap = nullptr; if (lstm_choice_mode) { - choiceMap = res_it->GetBestLSTMSymbolChoices(); - symbolMap = res_it->GetSegmentedLSTMTimesteps(); + CTCMap = res_it->GetBestLSTMSymbolChoices(); rawTimestepMap = res_it->GetRawLSTMTimesteps(); - CTCMap = res_it->GetBestCTCSymbolChoices(); } hocr_str << "\n "; + for (auto timestep : *symbol) { + hocr_str << "\n "; + for (auto conf : timestep) { + hocr_str << "\n " << HOcrEscape(conf.first).c_str() + << ""; + ++ccnt; + } + hocr_str << ""; + ++tcnt; + } + hocr_str << "\n "; + ++scnt; + } else if (lstm_choice_mode == 2) { + tesseract::ChoiceIterator ci(*res_it); + hocr_str << "\n "; + do { + const char* choice = ci.GetUTF8Text(); + float choiceconf = ci.Confidence(); + if (choice != nullptr) { + hocr_str << "\n " + << HOcrEscape(choice).c_str() << ""; + ccnt++; + } + } while (ci.Next()); + hocr_str << "\n "; + tcnt++; + } } } res_it->Next(RIL_SYMBOL); @@ -309,44 +353,32 @@ char* TessBaseAPI::GetHOCRText(ETEXT_DESC* monitor, int page_number) { if (italic) hocr_str << ""; if (bold) hocr_str << ""; // If the lstm choice mode is required it is added here - if (lstm_choice_mode == 1 && rawTimestepMap != nullptr) { - for (auto timestep : *rawTimestepMap) { - hocr_str << "\n "; - for (std::pair conf : timestep) { - hocr_str << "" - << HOcrEscape(conf.first).c_str() << ""; - gcnt++; - } - hocr_str << ""; - tcnt++; - } - } else if (lstm_choice_mode == 2 && choiceMap != nullptr) { - for (auto timestep : *choiceMap) { - if (timestep.size() > 0) { - hocr_str << "\n "; + for (auto timestep : symbol) { + hocr_str << "\n "; - for (auto & j : timestep) { - hocr_str << "" << HOcrEscape(j.first).c_str() << ""; - gcnt++; + << " title='x_confs " << int(conf.second * 100) << "'>" + << HOcrEscape(conf.first).c_str() << ""; + ++ccnt; } hocr_str << ""; - tcnt++; + ++tcnt; } + hocr_str << ""; + ++scnt; } - } else if (lstm_choice_mode == 4) { + } else if (lstm_choice_mode == 2 && !hocr_boxes && CTCMap != nullptr) { for (auto timestep : *CTCMap) { if (timestep.size() > 0) { hocr_str << "\n 100.0f) conf = 100.0f; - hocr_str << "" << HOcrEscape(j.first).c_str() << ""; - gcnt++; + ccnt++; } hocr_str << ""; tcnt++; } } - } else if (lstm_choice_mode == 3 && symbolMap != nullptr) { - for (auto timesteps : *symbolMap) { - hocr_str << "\n "; - for (auto timestep : timesteps) { - hocr_str << "\n "; - for (std::pair conf : timestep) { - hocr_str << "" - << HOcrEscape(conf.first).c_str() << ""; - gcnt++; - } - hocr_str << ""; - tcnt++; - } - hocr_str << ""; - scnt++; - } - } + } // Close ocrx_word. if (hocr_boxes || lstm_choice_mode > 0) { hocr_str << "\n "; } hocr_str << ""; tcnt = 1; - gcnt = 1; + ccnt = 1; wcnt++; // Close any ending block/paragraph/textline. if (last_word_in_line) { diff --git a/src/ccmain/linerec.cpp b/src/ccmain/linerec.cpp index f4399ce3d1..004d2e265c 100644 --- a/src/ccmain/linerec.cpp +++ b/src/ccmain/linerec.cpp @@ -240,7 +240,7 @@ void Tesseract::LSTMRecognizeWord(const BLOCK& block, ROW *row, WERD_RES *word, lstm_recognizer_->RecognizeLine(*im_data, true, classify_debug_level > 0, kWorstDictCertainty / kCertaintyScale, word_box, words, lstm_choice_mode, - lstm_choice_amount); + lstm_choice_iterations); delete im_data; SearchWords(words); } diff --git a/src/ccmain/ltrresultiterator.cpp b/src/ccmain/ltrresultiterator.cpp index f709c9d424..8276f43ebc 100644 --- a/src/ccmain/ltrresultiterator.cpp +++ b/src/ccmain/ltrresultiterator.cpp @@ -43,7 +43,8 @@ LTRResultIterator::~LTRResultIterator() = default; // Returns the null terminated UTF-8 encoded text string for the current // object at the given level. Use delete [] to free after use. char* LTRResultIterator::GetUTF8Text(PageIteratorLevel level) const { - if (it_->word() == nullptr) return nullptr; // Already at the end! + if (it_->word() == nullptr) + return nullptr; // Already at the end! STRING text; PAGE_RES_IT res_it(*it_); WERD_CHOICE* best_choice = res_it.word()->best_choice; @@ -70,7 +71,8 @@ char* LTRResultIterator::GetUTF8Text(PageIteratorLevel level) const { eop = res_it.block() != res_it.prev_block() || res_it.row()->row->para() != res_it.prev_row()->row->para(); } while (level != RIL_TEXTLINE && !eop); - if (eop) text += paragraph_separator_; + if (eop) + text += paragraph_separator_; } while (level == RIL_BLOCK && res_it.block() == res_it.prev_block()); } int length = text.length() + 1; @@ -92,7 +94,8 @@ void LTRResultIterator::SetParagraphSeparator(const char* new_para) { // Returns the mean confidence of the current object at the given level. // The number should be interpreted as a percent probability. (0.0f-100.0f) float LTRResultIterator::Confidence(PageIteratorLevel level) const { - if (it_->word() == nullptr) return 0.0f; // Already at the end! + if (it_->word() == nullptr) + return 0.0f; // Already at the end! float mean_certainty = 0.0f; int certainty_count = 0; PAGE_RES_IT res_it(*it_); @@ -211,18 +214,23 @@ const char* LTRResultIterator::WordRecognitionLanguage() const { // Return the overall directionality of this word. StrongScriptDirection LTRResultIterator::WordDirection() const { - if (it_->word() == nullptr) return DIR_NEUTRAL; + if (it_->word() == nullptr) + return DIR_NEUTRAL; bool has_rtl = it_->word()->AnyRtlCharsInWord(); bool has_ltr = it_->word()->AnyLtrCharsInWord(); - if (has_rtl && !has_ltr) return DIR_RIGHT_TO_LEFT; - if (has_ltr && !has_rtl) return DIR_LEFT_TO_RIGHT; - if (!has_ltr && !has_rtl) return DIR_NEUTRAL; + if (has_rtl && !has_ltr) + return DIR_RIGHT_TO_LEFT; + if (has_ltr && !has_rtl) + return DIR_LEFT_TO_RIGHT; + if (!has_ltr && !has_rtl) + return DIR_NEUTRAL; return DIR_MIX; } // Returns true if the current word was found in a dictionary. bool LTRResultIterator::WordIsFromDictionary() const { - if (it_->word() == nullptr) return false; // Already at the end! + if (it_->word() == nullptr) + return false; // Already at the end! int permuter = it_->word()->best_choice->permuter(); return permuter == SYSTEM_DAWG_PERM || permuter == FREQ_DAWG_PERM || permuter == USER_DAWG_PERM; @@ -230,13 +238,15 @@ bool LTRResultIterator::WordIsFromDictionary() const { // Returns the number of blanks before the current word. int LTRResultIterator::BlanksBeforeWord() const { - if (it_->word() == nullptr) return 1; + if (it_->word() == nullptr) + return 1; return it_->word()->word->space(); } // Returns true if the current word is numeric. bool LTRResultIterator::WordIsNumeric() const { - if (it_->word() == nullptr) return false; // Already at the end! + if (it_->word() == nullptr) + return false; // Already at the end! int permuter = it_->word()->best_choice->permuter(); return permuter == NUMBER_PERM; } @@ -269,7 +279,8 @@ const char* LTRResultIterator::GetBlamerMisadaptionDebug() const { // Returns true if a truth string was recorded for the current word. bool LTRResultIterator::HasTruthString() const { - if (it_->word() == nullptr) return false; // Already at the end! + if (it_->word() == nullptr) + return false; // Already at the end! if (it_->word()->blamer_bundle == nullptr || it_->word()->blamer_bundle->NoTruth()) { return false; // no truth information for this word @@ -280,7 +291,8 @@ bool LTRResultIterator::HasTruthString() const { // Returns true if the given string is equivalent to the truth string for // the current word. bool LTRResultIterator::EquivalentToTruth(const char* str) const { - if (!HasTruthString()) return false; + if (!HasTruthString()) + return false; ASSERT_HOST(it_->word()->uch_set != nullptr); WERD_CHOICE str_wd(str, *(it_->word()->uch_set)); return it_->word()->blamer_bundle->ChoiceIsCorrect(&str_wd); @@ -289,7 +301,8 @@ bool LTRResultIterator::EquivalentToTruth(const char* str) const { // Returns the null terminated UTF-8 encoded truth string for the current word. // Use delete [] to free after use. char* LTRResultIterator::WordTruthUTF8Text() const { - if (!HasTruthString()) return nullptr; + if (!HasTruthString()) + return nullptr; STRING truth_text = it_->word()->blamer_bundle->TruthString(); int length = truth_text.length() + 1; char* result = new char[length]; @@ -300,7 +313,8 @@ char* LTRResultIterator::WordTruthUTF8Text() const { // Returns the null terminated UTF-8 encoded normalized OCR string for the // current word. Use delete [] to free after use. char* LTRResultIterator::WordNormedUTF8Text() const { - if (it_->word() == nullptr) return nullptr; // Already at the end! + if (it_->word() == nullptr) + return nullptr; // Already at the end! STRING ocr_text; WERD_CHOICE* best_choice = it_->word()->best_choice; const UNICHARSET* unicharset = it_->word()->uch_set; @@ -317,8 +331,10 @@ char* LTRResultIterator::WordNormedUTF8Text() const { // Returns a pointer to serialized choice lattice. // Fills lattice_size with the number of bytes in lattice data. const char* LTRResultIterator::WordLattice(int* lattice_size) const { - if (it_->word() == nullptr) return nullptr; // Already at the end! - if (it_->word()->blamer_bundle == nullptr) return nullptr; + if (it_->word() == nullptr) + return nullptr; // Already at the end! + if (it_->word()->blamer_bundle == nullptr) + return nullptr; *lattice_size = it_->word()->blamer_bundle->lattice_size(); return it_->word()->blamer_bundle->lattice_data(); } @@ -357,11 +373,15 @@ ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) { oemLSTM_ = word_res_->tesseract->AnyLSTMLang(); oemLegacy_ = word_res_->tesseract->AnyTessLang(); rating_coefficient_ = word_res_->tesseract->lstm_rating_coefficient; - BLOB_CHOICE_LIST* choices = nullptr; + blanks_before_word_ = result_it.BlanksBeforeWord(); + BLOB_CHOICE_LIST* choices = nullptr; tstep_index_ = &result_it.blob_index_; if (oemLSTM_ && !word_res_->CTC_symbol_choices.empty()) { + if (strcmp(word_res_->CTC_symbol_choices[0][0].first, " ")) { + blanks_before_word_ = 0; + } auto index = *tstep_index_; - if (word_res_->leading_space) ++index; + index += blanks_before_word_; if (index < word_res_->CTC_symbol_choices.size()) { LSTM_choices_ = &word_res_->CTC_symbol_choices[index]; filterSpaces(); @@ -379,7 +399,9 @@ ChoiceIterator::ChoiceIterator(const LTRResultIterator& result_it) { LSTM_choice_it_ = LSTM_choices_->begin(); } } -ChoiceIterator::~ChoiceIterator() { delete choice_it_; } +ChoiceIterator::~ChoiceIterator() { + delete choice_it_; +} // Moves to the next choice for the symbol and returns false if there // are none left. @@ -393,7 +415,8 @@ bool ChoiceIterator::Next() { return true; } } else { - if (choice_it_ == nullptr) return false; + if (choice_it_ == nullptr) + return false; choice_it_->forward(); return !choice_it_->cycled_list(); } @@ -406,7 +429,8 @@ const char* ChoiceIterator::GetUTF8Text() const { std::pair choice = *LSTM_choice_it_; return choice.first; } else { - if (choice_it_ == nullptr) return nullptr; + if (choice_it_ == nullptr) + return nullptr; UNICHAR_ID id = choice_it_->data()->unichar_id(); return word_res_->uch_set->id_to_unichar_ext(id); } @@ -416,15 +440,16 @@ const char* ChoiceIterator::GetUTF8Text() const { // data. If only LSTM traineddata is used the value range is 0.0f - 1.0f. All // choices for one symbol should roughly add up to 1.0f. // If only traineddata of the legacy engine is used, the number should be -// interpreted as a percent probability. (0.0f-100.0f) In this case probabilities -// won't add up to 100. Each one stands on its own. +// interpreted as a percent probability. (0.0f-100.0f) In this case +// probabilities won't add up to 100. Each one stands on its own. float ChoiceIterator::Confidence() const { float confidence; if (oemLSTM_ && LSTM_choices_ != nullptr && !LSTM_choices_->empty()) { std::pair choice = *LSTM_choice_it_; confidence = 100 - rating_coefficient_ * choice.second; } else { - if (choice_it_ == nullptr) return 0.0f; + if (choice_it_ == nullptr) + return 0.0f; confidence = 100 + 5 * choice_it_->data()->certainty(); } return ClipToRange(confidence, 0.0f, 100.0f); @@ -433,16 +458,16 @@ float ChoiceIterator::Confidence() const { // Returns the set of timesteps which belong to the current symbol std::vector>>* ChoiceIterator::Timesteps() const { - if (word_res_->symbol_steps.empty() || !oemLSTM_) return nullptr; - if (word_res_->leading_space) { - return &word_res_->symbol_steps[*(tstep_index_) + 1]; - } else { - return &word_res_->symbol_steps[*tstep_index_]; + int offset = *tstep_index_ + blanks_before_word_; + if (offset >= word_res_->segmented_timesteps.size() || !oemLSTM_) { + return nullptr; } + return &word_res_->segmented_timesteps[*tstep_index_ + blanks_before_word_]; } void ChoiceIterator::filterSpaces() { - if (LSTM_choices_->empty()) return; + if (LSTM_choices_->empty()) + return; std::vector>::iterator it; for (it = LSTM_choices_->begin(); it != LSTM_choices_->end();) { if (!strcmp(it->first, " ")) { diff --git a/src/ccmain/ltrresultiterator.h b/src/ccmain/ltrresultiterator.h index d0e5eaccc0..ecc5a668aa 100644 --- a/src/ccmain/ltrresultiterator.h +++ b/src/ccmain/ltrresultiterator.h @@ -239,6 +239,8 @@ class ChoiceIterator { bool oemLegacy_; // regulates the rating granularity double rating_coefficient_; + // leading blanks + int blanks_before_word_; }; } // namespace tesseract. diff --git a/src/ccmain/resultiterator.cpp b/src/ccmain/resultiterator.cpp index f896857374..3f36078fc4 100644 --- a/src/ccmain/resultiterator.cpp +++ b/src/ccmain/resultiterator.cpp @@ -639,10 +639,10 @@ char* ResultIterator::GetUTF8Text(PageIteratorLevel level) const { strncpy(result, text.string(), length); return result; } -std::vector>>* +std::vector>>>* ResultIterator::GetRawLSTMTimesteps() const { if (it_->word() != nullptr) { - return &it_->word()->raw_timesteps; + return &it_->word()->segmented_timesteps; } else { return nullptr; } @@ -650,15 +650,6 @@ ResultIterator::GetRawLSTMTimesteps() const { std::vector>>* ResultIterator::GetBestLSTMSymbolChoices() const { - if (it_->word() != nullptr) { - return &it_->word()->accumulated_timesteps; - } else { - return nullptr; - } -} - -std::vector>>* -ResultIterator::GetBestCTCSymbolChoices() const { if (it_->word() != nullptr) { return &it_->word()->CTC_symbol_choices; } else { @@ -666,15 +657,6 @@ ResultIterator::GetBestCTCSymbolChoices() const { } } -std::vector>>>* -ResultIterator::GetSegmentedLSTMTimesteps() const { - if (it_->word() != nullptr) { - return &it_->word()->symbol_steps; - } else { - return nullptr; - } -} - void ResultIterator::AppendUTF8WordText(STRING* text) const { if (!it_->word()) return; diff --git a/src/ccmain/resultiterator.h b/src/ccmain/resultiterator.h index bffe8969a6..445305b722 100644 --- a/src/ccmain/resultiterator.h +++ b/src/ccmain/resultiterator.h @@ -100,14 +100,10 @@ class TESS_API ResultIterator : public LTRResultIterator { /** * Returns the LSTM choices for every LSTM timestep for the current word. */ - virtual std::vector>>* + virtual std::vector>>>* GetRawLSTMTimesteps() const; virtual std::vector>>* GetBestLSTMSymbolChoices() const; - virtual std::vector>>* - GetBestCTCSymbolChoices() const; - virtual std::vector>>>* - GetSegmentedLSTMTimesteps() const; /** * Return whether the current paragraph's dominant reading direction diff --git a/src/ccmain/tesseractclass.cpp b/src/ccmain/tesseractclass.cpp index 9f847893f6..8a5f7fc194 100644 --- a/src/ccmain/tesseractclass.cpp +++ b/src/ccmain/tesseractclass.cpp @@ -522,19 +522,15 @@ Tesseract::Tesseract() this->params()), INT_MEMBER(lstm_choice_mode, 0, "Allows to include alternative symbols choices in the hOCR output. " - "Valid input values are 0, 1, 2 and 3. 0 is the default value. " + "Valid input values are 0, 1 and 2. 0 is the default value. " "With 1 the alternative symbol choices per timestep are included. " - "With 2 the alternative symbol choices are accumulated per " - "character. " - "With 3 the alternative symbol choices per timestep are included " - "and separated by the suggested segmentation of Tesseract. " - "With 4 alternative symbol choices are extracted from the CTC " + "With 2 alternative symbol choices are extracted from the CTC " "process instead of the lattice. The choices are mapped per " "character.", this->params()), INT_MEMBER( - lstm_choice_amount, 5, - "Sets the number of choices one get per character in " + lstm_choice_iterations, 5, + "Sets the number of cascading iterations for the Beamsearch in " "lstm_choice_mode. Note that lstm_choice_mode must be set to a " "value greater than 0 to produce results.", this->params()), diff --git a/src/ccmain/tesseractclass.h b/src/ccmain/tesseractclass.h index 518b8d193d..834d4a41cf 100644 --- a/src/ccmain/tesseractclass.h +++ b/src/ccmain/tesseractclass.h @@ -1086,17 +1086,13 @@ class Tesseract : public Wordrec { INT_VAR_H(lstm_choice_mode, 0, "Allows to include alternative symbols choices in the hOCR " "output. " - "Valid input values are 0, 1, 2 and 3. 0 is the default value. " + "Valid input values are 0, 1 and 2. 0 is the default value. " "With 1 the alternative symbol choices per timestep are included. " - "With 2 the alternative symbol choices are accumulated per " - "character. " - "With 3 the alternative symbol choices per timestep are included " - "and separated by the suggested segmentation of Tesseract. " - "With 4 alternative symbol choices are extracted from the CTC " + "With 2 the alternative symbol choices are extracted from the CTC " "process instead of the lattice. The choices are mapped per " "character."); - INT_VAR_H(lstm_choice_amount, 5, - "Sets the number of choices one get per character in " + INT_VAR_H(lstm_choice_iterations, 5, + "Sets the number of cascading iterations for the Beamsearch in " "lstm_choice_mode. Note that lstm_choice_mode must be set to " "a value greater than 0 to produce results."); double_VAR_H(lstm_rating_coefficient, 5, diff --git a/src/ccstruct/pageres.h b/src/ccstruct/pageres.h index bf4359c6a6..9c543d51a1 100644 --- a/src/ccstruct/pageres.h +++ b/src/ccstruct/pageres.h @@ -219,14 +219,16 @@ class WERD_RES : public ELIST_LINK { // blob i and blob i+1. GenericVector blob_gaps; // Stores the lstm choices of every timestep - std::vector>> raw_timesteps; - std::vector>> accumulated_timesteps; - std::vector>>> - symbol_steps; + std::vector>> timesteps; + // Stores the lstm choices of every timestep segmented by character + std::vector>>> segmented_timesteps; //Symbolchoices aquired during CTC std::vector>> CTC_symbol_choices; // Stores if the timestep vector starts with a space bool leading_space = false; + // Stores value when the word ends + int end; // Ratings matrix contains classifier choices for each classified combination // of blobs. The dimension is the same as the number of blobs in chopped_word // and the leading diagonal corresponds to classifier results of the blobs diff --git a/src/lstm/lstmrecognizer.cpp b/src/lstm/lstmrecognizer.cpp index 13e1778c3a..39755d5c26 100644 --- a/src/lstm/lstmrecognizer.cpp +++ b/src/lstm/lstmrecognizer.cpp @@ -208,15 +208,22 @@ void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert, lstm_choice_mode); search_->extractSymbolChoices(&GetUnicharset()); } - int ctc_it = 0; + search_->segmentTimestepsByCharacters(); + int char_it = 0; for (int i = 0; i < words->size(); ++i) { - for (int j = 0; j < words->get(i)->accumulated_timesteps.size(); ++j) { - if (ctc_it < search_->ctc_choices.size()) + for (int j = 0; j < words->get(i)->end; ++j) { + if (char_it < search_->ctc_choices.size()) words->get(i)->CTC_symbol_choices.push_back( - search_->ctc_choices[ctc_it]); - ++ctc_it; + search_->ctc_choices[char_it]); + if (char_it < search_->segmentedTimesteps.size()) + words->get(i)->segmented_timesteps.push_back( + search_->segmentedTimesteps[char_it]); + ++char_it; } + words->get(i)->timesteps = search_->combineSegmentedTimesteps( + &words->get(i)->segmented_timesteps); } + search_->segmentedTimesteps.clear(); search_->ctc_choices.clear(); search_->excludedUnichars.clear(); } diff --git a/src/lstm/recodebeam.cpp b/src/lstm/recodebeam.cpp index 105bdd943f..af91310287 100644 --- a/src/lstm/recodebeam.cpp +++ b/src/lstm/recodebeam.cpp @@ -119,8 +119,8 @@ void RecodeBeamSearch::DecodeSecondaryBeams(const NetworkIO& output, { ++bucketNumber; } - ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t), output.NumFeatures(), - kBeamWidths[0]); + ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t), + output.NumFeatures(), kBeamWidths[0]); DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, charset); } @@ -154,6 +154,28 @@ void RecodeBeamSearch::SaveMostCertainChoices(const float* outputs, timesteps.push_back(choices); } +void RecodeBeamSearch::segmentTimestepsByCharacters() { + for (int i = 1; i < character_boundaries_.size(); ++i){ + std::vector>> segment; + for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i]; ++j) { + segment.push_back(timesteps[j]); + } + segmentedTimesteps.push_back(segment); + } +} +std::vector>> + RecodeBeamSearch::combineSegmentedTimesteps( + std::vector>>>* + segmentedTimesteps) { + std::vector>> combined_timesteps; + for (int i = 0; i < segmentedTimesteps->size(); ++i){ + for (int j = 0; j < (*segmentedTimesteps)[i].size(); ++j) { + combined_timesteps.push_back((*segmentedTimesteps)[i][j]); + } + } + return combined_timesteps; +} + void RecodeBeamSearch::calculateCharBoundaries(std::vector* starts, std::vector* ends, std::vector* char_bounds_, @@ -218,8 +240,6 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, GenericVector xcoords; GenericVector best_nodes; GenericVector second_nodes; - std::deque> best_choices; - std::deque> best_choices_acc; character_boundaries_.clear(); ExtractBestPaths(&best_nodes, &second_nodes); if (debug) { @@ -230,22 +250,11 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings, xcoords); } - int timestepEndRaw = 0; - int timestepEnd = 0; - int timestepEnd_acc = 0; // If lstm choice mode is required in granularity level 2, it stores the x // Coordinates of every chosen character, to match the alternative choices to // it. ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords, - &character_boundaries_, &best_choices, - &best_choices_acc); - if (lstm_choice_mode) { - if (best_choices.size() > 0) { - timestepEnd = std::get<1>(best_choices.front()); - timestepEnd_acc = std::get<1>(best_choices_acc.front()); - best_choices_acc.pop_front(); - } - } + &character_boundaries_); int num_ids = unichar_ids.size(); if (debug) { DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings, @@ -277,71 +286,6 @@ void RecodeBeamSearch::ExtractBestPathAsWords(const TBOX& line_box, InitializeWord(leading_space, line_box, word_start, word_end, std::min(space_cert, prev_space_cert), unicharset, xcoords, scale_factor); - if (lstm_choice_mode) { - for (size_t i = timestepEndRaw; i < xcoords[word_end]; i++) { - word_res->raw_timesteps.push_back(timesteps[i]); - } - timestepEndRaw = xcoords[word_end]; - // Accumulated Timesteps (choice mode 2 processing) - float sum = 0; - std::vector> choice_pairs; - for (size_t i = timestepEnd_acc; i < xcoords[word_end]; i++) { - for (std::pair choice : timesteps[i]) { - if (std::strcmp(choice.first, "")) { - sum += choice.second; - choice_pairs.push_back(choice); - } - } - if ((best_choices_acc.size() > 0 && - i == std::get<1>(best_choices_acc.front()) - 1) || - i == xcoords[word_end] - 1) { - std::map summed_propabilities; - for (auto & choice_pair : choice_pairs) { - summed_propabilities[choice_pair.first] += choice_pair.second; - } - std::vector> accumulated_timestep; - int pos; - for (auto& summed_propability : summed_propabilities) { - if (sum == 0) break; - summed_propability.second /= sum; - pos = 0; - while (accumulated_timestep.size() > pos && - accumulated_timestep[pos].second > - summed_propability.second) { - pos++; - } - accumulated_timestep.insert( - accumulated_timestep.begin() + pos, - std::pair(summed_propability.first, - summed_propability.second)); - } - if (best_choices_acc.size() > 0) { - best_choices_acc.pop_front(); - } - choice_pairs.clear(); - word_res->accumulated_timesteps.push_back(accumulated_timestep); - sum = 0; - } - } - timestepEnd_acc = xcoords[word_end]; - // Symbol Step (choice mode 3 processing) - std::vector>> currentSymbol; - for (size_t i = timestepEnd; i < xcoords[word_end]; i++) { - if (i == std::get<1>(best_choices.front())) { - if (currentSymbol.size() > 0) { - word_res->symbol_steps.push_back(currentSymbol); - currentSymbol.clear(); - } - const char* leadCharacter = - unicharset->id_to_unichar_ext(std::get<0>(best_choices.front())); - if (!strcmp(leadCharacter, " ")) word_res->leading_space = true; - if (best_choices.size() > 1) best_choices.pop_front(); - } - currentSymbol.push_back(timesteps[i]); - } - word_res->symbol_steps.push_back(currentSymbol); - timestepEnd = xcoords[word_end]; - } for (int i = word_start; i < word_end; ++i) { auto* choices = new BLOB_CHOICE_LIST; BLOB_CHOICE_IT bc_it(choices); @@ -602,9 +546,7 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( const GenericVector& best_nodes, GenericVector* unichar_ids, GenericVector* certs, GenericVector* ratings, GenericVector* xcoords, - std::vector* character_boundaries, - std::deque>* best_choices, - std::deque>* best_choices_acc) { + std::vector* character_boundaries) { unichar_ids->truncate(0); certs->truncate(0); ratings->truncate(0); @@ -615,8 +557,6 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( int t = 0; int width = best_nodes.size(); while (t < width) { - int id; - int tposition; double certainty = 0.0; double rating = 0.0; while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) { @@ -638,10 +578,6 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( } unichar_ids->push_back(unichar_id); xcoords->push_back(t); - if (best_choices != nullptr) { - tposition = t; - id = unichar_id; - } do { double cert = best_nodes[t++]->certainty; // Special-case NO-PERM space to forget the certainty of the previous @@ -659,10 +595,6 @@ void RecodeBeamSearch::ExtractPathAsUnicharIds( if (certainty < certs->back()) certs->back() = certainty; ratings->back() += rating; } - if (best_choices != nullptr) { - best_choices->push_back(std::tuple(id, tposition)); - best_choices_acc->push_back(std::tuple(id, tposition)); - } } starts.push_back(width); if (character_boundaries != nullptr) { @@ -696,6 +628,7 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space, WERD* word = new WERD(&blobs, leading_space, nullptr); // Make a WERD_RES from the word. auto* word_res = new WERD_RES(word); + word_res->end = word_end - word_start + leading_space; word_res->uch_set = unicharset; word_res->combination = true; // Give it ownership of the word. word_res->space_certainty = space_certainty; @@ -742,7 +675,8 @@ void RecodeBeamSearch::ComputeSecTopN(std::unordered_set* exList, second_code_ = -1; top_heap_.clear(); for (int i = 0; i < num_outputs; ++i) { - if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) && !exList->count(i)) { + if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) + && !exList->count(i)) { TopPair entry(outputs[i], i); top_heap_.Push(&entry); if (top_heap_.size() > top_n) top_heap_.Pop(&entry); diff --git a/src/lstm/recodebeam.h b/src/lstm/recodebeam.h index 75ea3efbd6..6fcab6981d 100644 --- a/src/lstm/recodebeam.h +++ b/src/lstm/recodebeam.h @@ -227,10 +227,21 @@ class RecodeBeamSearch { // Generates debug output of the content of the beams after a Decode. void PrintBeam2(bool uids, int num_outputs, const UNICHARSET* charset, bool secondary) const; + // Segments the timestep bundle by the character_boundaries. + void segmentTimestepsByCharacters(); + std::vector>> + // Unions the segmented timestep character bundles to one big bundle. + combineSegmentedTimesteps( + std::vector>>>* + segmentedTimesteps); // Stores the alternative characters of every timestep together with their // probability. std::vector< std::vector>> timesteps; + std::vector>>> + segmentedTimesteps; + // Stores the character choices found in the ctc algorithm std::vector>> ctc_choices; + // Stores all unicharids which are excluded for future iterations std::vector> excludedUnichars; // Stores the character boundaries regarding timesteps. std::vector character_boundaries_; @@ -301,9 +312,7 @@ class RecodeBeamSearch { const GenericVector& best_nodes, GenericVector* unichar_ids, GenericVector* certs, GenericVector* ratings, GenericVector* xcoords, - std::vector* character_boundaries = nullptr, - std::deque>* best_choices = nullptr, - std::deque>* best_choices_acc = nullptr); + std::vector* character_boundaries = nullptr); // Sets up a word with the ratings matrix and fake blobs with boxes in the // right places.