diff --git a/Source/Engine/McBopomofoLM.cpp b/Source/Engine/McBopomofoLM.cpp index c4bafb47..7c7ef151 100644 --- a/Source/Engine/McBopomofoLM.cpp +++ b/Source/Engine/McBopomofoLM.cpp @@ -23,8 +23,9 @@ #include "McBopomofoLM.h" #include -#include #include +#include +#include namespace McBopomofo { @@ -136,24 +137,18 @@ bool McBopomofoLM::hasUnigrams(const std::string& key) return getUnigrams(key).size() > 0; } -std::string McBopomofoLM::getReading(const std::string_view& value) +std::string McBopomofoLM::getReading(const std::string& value) { - std::vector records = m_languageModel.getReadings(value); - - double highScore = -DBL_MAX; - std::string highScoringValue; - for (std::string record : records) { - std::vector parts = split(record, ' '); - if (parts.size() == 3) { - double score = std::stod(std::string(parts[2])); - if (score > highScore) { - highScoringValue = std::string(parts[0]); - highScore = score; - } + std::vector foundReadings = m_languageModel.getReadings(value); + double topScore = std::numeric_limits::lowest(); + std::string topValue; + for (const auto& foundReading : foundReadings) { + if (foundReading.score > topScore) { + topValue = foundReading.reading; + topScore = foundReading.score; } } - - return highScoringValue; + return topValue; } void McBopomofoLM::setPhraseReplacementEnabled(bool enabled) @@ -222,15 +217,4 @@ bool McBopomofoLM::hasAssociatedPhrasesForKey(const std::string& key) return m_associatedPhrases.hasValuesForKey(key); } -std::vector McBopomofoLM::split(const std::string_view& str, char delim) { - std::vector strings; - size_t start; - size_t end = 0; - while ((start = str.find_first_not_of(delim, end)) != std::string_view::npos) { - end = str.find(delim, start); - strings.push_back(std::string_view(str.substr(start, end - start))); - } - return strings; -} - } // namespace McBopomofo diff --git a/Source/Engine/McBopomofoLM.h b/Source/Engine/McBopomofoLM.h index f1f43ef1..870aab91 100644 --- a/Source/Engine/McBopomofoLM.h +++ b/Source/Engine/McBopomofoLM.h @@ -105,11 +105,8 @@ class McBopomofoLM : public Formosa::Gramambular2::LanguageModel { const std::vector associatedPhrasesForKey(const std::string& key); bool hasAssociatedPhrasesForKey(const std::string& key); - /// Returns a list of readings that match a given value. - /// @param value A string representing the text to look up reading candidates for. For example, - /// if you pass "說", it returns a list of records that include ㄕㄨㄛ, ㄕㄨㄟˋ, and ㄩㄝˋ. - /// @return Best reading found for the string, or an empty string if no matches are found. - std::string getReading(const std::string_view& value); + /// Returns the top-scored reading from the base model, given the value. + std::string getReading(const std::string& value); protected: /// Filters and converts the input unigrams and return a new list of unigrams. @@ -123,12 +120,6 @@ class McBopomofoLM : public Formosa::Gramambular2::LanguageModel { const std::unordered_set& excludedValues, std::unordered_set& insertedValues); - /// Splits a string into parts - /// @param str The string to split. - /// @param delim Delimiter character in the string to split on. - /// @return vector of split-up strings - std::vector split(const std::string_view& str, char delim); - ParselessLM m_languageModel; UserPhrasesLM m_userPhrases; UserPhrasesLM m_excludedPhrases; diff --git a/Source/Engine/ParselessLM.cpp b/Source/Engine/ParselessLM.cpp index d5cc3073..0f4266ed 100644 --- a/Source/Engine/ParselessLM.cpp +++ b/Source/Engine/ParselessLM.cpp @@ -145,11 +145,53 @@ bool McBopomofo::ParselessLM::hasUnigrams(const std::string& key) return db_->findFirstMatchingLine(key + " ") != nullptr; } -std::vector McBopomofo::ParselessLM::getReadings(const std::string_view& value) +std::vector McBopomofo::ParselessLM::getReadings(const std::string& value) { if (db_ == nullptr) { - return std::vector(); + return std::vector(); } - - return db_->reverseFindRows(value); + + std::vector results; + + // We append a space so that we only find rows with the exact value. We + // are taking advantage of the fact that a well-form row in this LM must + // be in the format of "key value score". + std::string actualValue = value + " "; + + for (const auto& row : db_->reverseFindRows(actualValue)) { + std::string key; + double score = 0; + + // Move ahead until we encounter the first space. This is the key. + auto it = row.begin(); + while (it != row.end() && *it != ' ') { + ++it; + } + + key = std::string(row.begin(), it); + + // Read past the space. + if (it != row.end()) { + ++it; + } + + if (it != row.end()) { + // Now it is the start of the value portion, but we move ahead + // until we encounter the second space to skip this part. + while (it != row.end() && *it != ' ') { + ++it; + } + } + + // Read past the space. The remainder, if it exists, is the score. + if (it != row.end()) { + ++it; + } + + if (it != row.end()) { + score = std::stod(std::string(it, row.end())); + } + results.emplace_back(McBopomofo::ParselessLM::FoundReading { key, score }); + } + return results; } diff --git a/Source/Engine/ParselessLM.h b/Source/Engine/ParselessLM.h index f3751a1a..f215b9e1 100644 --- a/Source/Engine/ParselessLM.h +++ b/Source/Engine/ParselessLM.h @@ -45,7 +45,12 @@ class ParselessLM : public Formosa::Gramambular2::LanguageModel { const std::string& key) override; bool hasUnigrams(const std::string& key) override; - std::vector getReadings(const std::string_view& value); + struct FoundReading { + std::string reading; + double score; + }; + // Look up reading by value. This is specific to ParselessLM only. + std::vector getReadings(const std::string& value); private: int fd_ = -1; diff --git a/Source/Engine/ParselessLMTest.cpp b/Source/Engine/ParselessLMTest.cpp index d73aaf69..3694d5ba 100644 --- a/Source/Engine/ParselessLMTest.cpp +++ b/Source/Engine/ParselessLMTest.cpp @@ -53,6 +53,20 @@ TEST(ParselessLMTest, SanityCheckTest) unigrams = lm.getUnigrams("_punctuation_list"); ASSERT_GT(unigrams.size(), 0); + std::vector found_readings; + found_readings = lm.getReadings("不存在的詞"); + ASSERT_TRUE(found_readings.empty()); + + found_readings = lm.getReadings("讀音"); + ASSERT_EQ(found_readings.size(), 1); + + found_readings = lm.getReadings("鑰匙"); + ASSERT_GT(found_readings.size(), 1); + + found_readings = lm.getReadings("得"); + ASSERT_GT(found_readings.size(), 1); + ASSERT_EQ(found_readings[0].reading, "ㄉㄜˊ"); + lm.close(); } diff --git a/Source/Engine/ParselessPhraseDB.cpp b/Source/Engine/ParselessPhraseDB.cpp index 1602e026..f72a923b 100644 --- a/Source/Engine/ParselessPhraseDB.cpp +++ b/Source/Engine/ParselessPhraseDB.cpp @@ -172,7 +172,7 @@ std::vector ParselessPhraseDB::reverseFindRows( while (recordBegin < end_) { const char* ptr = recordBegin; - + // skip over the key to find the field separator while (ptr < end_ && *ptr != ' ') { ++ptr; @@ -181,19 +181,19 @@ std::vector ParselessPhraseDB::reverseFindRows( while (ptr < end_ && *ptr == ' ') { ++ptr; } - + // now walk to the end of this record const char* recordEnd = ptr; while (recordEnd < end_ && *recordEnd != '\n') { ++recordEnd; } - + if (ptr + value.length() < end_ && memcmp(ptr, value.data(), value.length()) == 0) { // prefix match, add entire record to return value rows.emplace_back(recordBegin, recordEnd - recordBegin); } - - // skip over the record separator. there should be just one, but loop just in case. + + // skip over to the next line start recordBegin = recordEnd; while (recordBegin < end_ && *recordBegin == '\n') { ++recordBegin; diff --git a/Source/Engine/ParselessPhraseDB.h b/Source/Engine/ParselessPhraseDB.h index 58815a2a..436a6cad 100644 --- a/Source/Engine/ParselessPhraseDB.h +++ b/Source/Engine/ParselessPhraseDB.h @@ -50,9 +50,12 @@ class ParselessPhraseDB { const char* findFirstMatchingLine(const std::string_view& key); - // Find the rows that (prefix-)match the value, useful for returning all the - // ways a phrase or character can be pronounced. Note that this is a potentially- - // slow linear search that cannot take advantage of the pre-sorting. + // Find the rows whose text past the key column plus the field separator + // is a prefix match of the given value. For example, if the row is + // "foo bar -1.00", the values "b", "ba", "bar", "bar ", "bar -1.00" are + // are valid prefix matches, where as the value "barr" isn't. This + // performs linear scan since, unlike lookup-by-key, it cannot take + // advantage of the fact that the underlying data is sorted by keys. std::vector reverseFindRows(const std::string_view& value); private: diff --git a/Source/Engine/ParselessPhraseDBTest.cpp b/Source/Engine/ParselessPhraseDBTest.cpp index 07c0b7d6..a116989a 100644 --- a/Source/Engine/ParselessPhraseDBTest.cpp +++ b/Source/Engine/ParselessPhraseDBTest.cpp @@ -195,4 +195,30 @@ TEST(ParselessPhraseDBTest, StressTest) } } +TEST(ParselessPhraseDBTest, LookUpByValue) +{ + std::string data = "a 1\nb 1 \nc 2\nd 3"; + ParselessPhraseDB db(data.c_str(), data.length()); + + std::vector rows; + rows = db.reverseFindRows("1"); + ASSERT_EQ(rows, (std::vector { "a 1", "b 1 " })); + + rows = db.reverseFindRows("2"); + ASSERT_EQ(rows, (std::vector { "c 2" })); + + // This is a quirk of the function, but is actually valid. + rows = db.reverseFindRows("2\n"); + ASSERT_EQ(rows, (std::vector { "c 2" })); + + rows = db.reverseFindRows("22"); + ASSERT_TRUE(rows.empty()); + + rows = db.reverseFindRows("3\n"); + ASSERT_TRUE(rows.empty()); + + rows = db.reverseFindRows("4"); + ASSERT_TRUE(rows.empty()); +} + }; // namespace McBopomofo diff --git a/Source/LanguageModelManager.mm b/Source/LanguageModelManager.mm index 9cef327e..92eb6033 100644 --- a/Source/LanguageModelManager.mm +++ b/Source/LanguageModelManager.mm @@ -311,7 +311,7 @@ + (nullable NSString *)readingFor:(NSString *)phrase { } std::string reading = gLanguageModelMcBopomofo.getReading(phrase.UTF8String); - return !reading.empty() ? [NSString stringWithCString:reading.c_str() encoding:NSUTF8StringEncoding] : nil; + return !reading.empty() ? [NSString stringWithUTF8String:reading.c_str()] : nil; } @end