diff --git a/src/lib.rs b/src/lib.rs index 6bf5cc8..47feb5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -73,6 +73,39 @@ pub fn detect_language(text: &str) -> Lang { weights::LANGUAGES[lang_id] } +/// Returns a Vec of tuples (language, weight) ascending by weight +pub fn detect_multiple_languages(text: &str) -> Vec<(Lang, f32)> { + let mut scores: [f32; NUM_LANGUAGES] = Default::default(); + let mut num_features = 0.0f32; + emit_tokens( + text, + #[inline(always)] + |token| { + num_features += 1.0f32; + let bucket = token.to_hash() % DIMENSION as u32; + let idx = bucket as usize * NUM_LANGUAGES; + let per_language_scores = &weights::WEIGHTS[idx..idx + NUM_LANGUAGES]; + for i in 0..NUM_LANGUAGES { + scores[i] += per_language_scores[i]; + } + }, + ); + for i in 0..NUM_LANGUAGES { + // Ok so the sqrt(num_features) is not really the norm, but whatever. + scores[i] = scores[i] / num_features.sqrt() + weights::INTERCEPTS[i]; + } + let mut langs: Vec<(Lang, f32)> = Vec::new(); + for (i, score) in scores.iter().enumerate() { + langs.push((LANGUAGES[i], score.to_owned())) + } + + langs.sort_by(|(_, score_left), (_, score_right)| score_left.partial_cmp(score_right).unwrap()); + + langs.reverse(); + + langs +} + #[doc(hidden)] pub fn emit_tokens(text: &str, mut listener: impl FnMut(Feature)) { let mut prev = ' ' as u32; @@ -187,6 +220,7 @@ fn classify_codepoint(chr: char) -> u32 { #[cfg(test)] mod tests { use crate::detect_language; + use crate::detect_multiple_languages; use crate::emit_tokens; use crate::Feature; use crate::Lang; @@ -195,7 +229,7 @@ mod tests { assert!(text.is_ascii()); let mut bytes: [u8; 4] = [0u8; 4]; assert!(text.len() <= 4); - bytes[4-text.len()..].copy_from_slice(text.as_bytes()); + bytes[4 - text.len()..].copy_from_slice(text.as_bytes()); Feature::AsciiNGram(u32::from_be_bytes(bytes)) } @@ -207,22 +241,17 @@ mod tests { &tokens, &[ ascii_ngram_feature(" h"), - ascii_ngram_feature("he"), ascii_ngram_feature(" he"), - ascii_ngram_feature("el"), ascii_ngram_feature("hel"), ascii_ngram_feature(" hel"), - ascii_ngram_feature("ll"), ascii_ngram_feature("ell"), ascii_ngram_feature("hell"), - ascii_ngram_feature("lo"), ascii_ngram_feature("llo"), ascii_ngram_feature("ello"), - Feature::Unicode(' '), Feature::UnicodeClass(' '), Feature::Unicode('こ'), @@ -257,11 +286,76 @@ mod tests { assert_eq!(detect_language("Ciao, felice contribuente!"), Lang::Ita); // Spanish assert_eq!(detect_language("Hola feliz contribuyente"), Lang::Spa); - assert_eq!( - detect_language("¡Hola!"), - Lang::Spa - ); + assert_eq!(detect_language("¡Hola!"), Lang::Spa); // Portuguese assert_eq!(detect_language("Olá feliz contribuinte"), Lang::Por); } + + #[test] + fn test_detect_language_with_score() { + // English + assert!( + matches!(detect_multiple_languages("Hello, happy tax payer").first().unwrap(), + &(Lang::Eng, _)) + ); + + // French + assert!( + matches!(detect_multiple_languages("Bonjour joyeux contribuable").first().unwrap(), + &(Lang::Fra, _)) + ); + + // German + assert!( + matches!(detect_multiple_languages("Hallo glücklicher Steuerzahler").first().unwrap(), + &(Lang::Deu, _)) + ); + + // Japanese + assert!( + matches!(detect_multiple_languages("こんにちは幸せな税金納め").first().unwrap(), + &(Lang::Jpn, _)) + ); + + // Mandarin chinese + assert!( + matches!(detect_multiple_languages("你好幸福的纳税人").first().unwrap(), + &(Lang::Cmn, _)) + ); + + // Turkish + assert!( + matches!(detect_multiple_languages("Merhaba, mutlu vergi mükellefi").first().unwrap(), + &(Lang::Tur, _)) + ); + + // Dutch + assert!( + matches!(detect_multiple_languages("Hallo, blije belastingbetaler").first().unwrap(), + &(Lang::Nld, _)) + ); + + // Korean + assert!( + matches!(detect_multiple_languages("안녕하세요 행복한 납세자입니다").first().unwrap(), + &(Lang::Kor, _)) + ); + + // Italian + assert!( + matches!(detect_multiple_languages("Ciao, felice contribuente!").first().unwrap(), + &(Lang::Ita, _)) + ); + + // Spanish + assert!( + matches!(detect_multiple_languages("Hola feliz contribuyente").first().unwrap(), + &(Lang::Spa, _)) + ); + + assert!( + matches!(detect_multiple_languages("¡Hola!").first().unwrap(), + &(Lang::Spa, _)) + ); + } }