Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WiP: Multiple language detection #15

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 104 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,39 @@ pub fn detect_language(text: &str) -> Lang {
weights::LANGUAGES[lang_id]
}

/// Returns a Vec of tuples (language, weight) ascending by weight
pub fn detect_multiple_languages(text: &str) -> Vec<(Lang, f32)> {
let mut scores: [f32; NUM_LANGUAGES] = Default::default();
let mut num_features = 0.0f32;
emit_tokens(
text,
#[inline(always)]
|token| {
num_features += 1.0f32;
let bucket = token.to_hash() % DIMENSION as u32;
let idx = bucket as usize * NUM_LANGUAGES;
let per_language_scores = &weights::WEIGHTS[idx..idx + NUM_LANGUAGES];
for i in 0..NUM_LANGUAGES {
scores[i] += per_language_scores[i];
}
},
);
for i in 0..NUM_LANGUAGES {
// Ok so the sqrt(num_features) is not really the norm, but whatever.
scores[i] = scores[i] / num_features.sqrt() + weights::INTERCEPTS[i];
}
let mut langs: Vec<(Lang, f32)> = Vec::new();
for (i, score) in scores.iter().enumerate() {
langs.push((LANGUAGES[i], score.to_owned()))
}

langs.sort_by(|(_, score_left), (_, score_right)| score_left.partial_cmp(score_right).unwrap());

langs.reverse();

langs
}

#[doc(hidden)]
pub fn emit_tokens(text: &str, mut listener: impl FnMut(Feature)) {
let mut prev = ' ' as u32;
Expand Down Expand Up @@ -187,6 +220,7 @@ fn classify_codepoint(chr: char) -> u32 {
#[cfg(test)]
mod tests {
use crate::detect_language;
use crate::detect_multiple_languages;
use crate::emit_tokens;
use crate::Feature;
use crate::Lang;
Expand All @@ -195,7 +229,7 @@ mod tests {
assert!(text.is_ascii());
let mut bytes: [u8; 4] = [0u8; 4];
assert!(text.len() <= 4);
bytes[4-text.len()..].copy_from_slice(text.as_bytes());
bytes[4 - text.len()..].copy_from_slice(text.as_bytes());
Feature::AsciiNGram(u32::from_be_bytes(bytes))
}

Expand All @@ -207,22 +241,17 @@ mod tests {
&tokens,
&[
ascii_ngram_feature(" h"),

ascii_ngram_feature("he"),
ascii_ngram_feature(" he"),

ascii_ngram_feature("el"),
ascii_ngram_feature("hel"),
ascii_ngram_feature(" hel"),

ascii_ngram_feature("ll"),
ascii_ngram_feature("ell"),
ascii_ngram_feature("hell"),

ascii_ngram_feature("lo"),
ascii_ngram_feature("llo"),
ascii_ngram_feature("ello"),

Feature::Unicode(' '),
Feature::UnicodeClass(' '),
Feature::Unicode('こ'),
Expand Down Expand Up @@ -257,11 +286,76 @@ mod tests {
assert_eq!(detect_language("Ciao, felice contribuente!"), Lang::Ita);
// Spanish
assert_eq!(detect_language("Hola feliz contribuyente"), Lang::Spa);
assert_eq!(
detect_language("¡Hola!"),
Lang::Spa
);
assert_eq!(detect_language("¡Hola!"), Lang::Spa);
// Portuguese
assert_eq!(detect_language("Olá feliz contribuinte"), Lang::Por);
}

#[test]
fn test_detect_language_with_score() {
// English
assert!(
matches!(detect_multiple_languages("Hello, happy tax payer").first().unwrap(),
&(Lang::Eng, _))
);

// French
assert!(
matches!(detect_multiple_languages("Bonjour joyeux contribuable").first().unwrap(),
&(Lang::Fra, _))
);

// German
assert!(
matches!(detect_multiple_languages("Hallo glücklicher Steuerzahler").first().unwrap(),
&(Lang::Deu, _))
);

// Japanese
assert!(
matches!(detect_multiple_languages("こんにちは幸せな税金納め").first().unwrap(),
&(Lang::Jpn, _))
);

// Mandarin chinese
assert!(
matches!(detect_multiple_languages("你好幸福的纳税人").first().unwrap(),
&(Lang::Cmn, _))
);

// Turkish
assert!(
matches!(detect_multiple_languages("Merhaba, mutlu vergi mükellefi").first().unwrap(),
&(Lang::Tur, _))
);

// Dutch
assert!(
matches!(detect_multiple_languages("Hallo, blije belastingbetaler").first().unwrap(),
&(Lang::Nld, _))
);

// Korean
assert!(
matches!(detect_multiple_languages("안녕하세요 행복한 납세자입니다").first().unwrap(),
&(Lang::Kor, _))
);

// Italian
assert!(
matches!(detect_multiple_languages("Ciao, felice contribuente!").first().unwrap(),
&(Lang::Ita, _))
);

// Spanish
assert!(
matches!(detect_multiple_languages("Hola feliz contribuyente").first().unwrap(),
&(Lang::Spa, _))
);

assert!(
matches!(detect_multiple_languages("¡Hola!").first().unwrap(),
&(Lang::Spa, _))
);
}
}