From d8e27fa2c3929b7a5c231f307fbd2e333e3fd5dc Mon Sep 17 00:00:00 2001 From: Omar MHAIMDAT Date: Sun, 29 Jan 2023 20:15:34 +0100 Subject: [PATCH] Add unittests --- src/lib.rs | 30 +++++++++++------------ test.py | 70 ++++++++++++++++++++++++++++++++++-------------------- 2 files changed, 59 insertions(+), 41 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 451c66f..64f74b2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ use pyo3::prelude::*; use rayon::prelude::*; +use std::time::Instant; use whatlang::{detect, detect_lang, detect_script}; use crate::utils::{colorize, get_progress_bar, lang_to_iso639_1, TermColor}; @@ -137,6 +138,17 @@ fn convert_to_py_lang(lang: whatlang::Lang) -> PyLang { } } +/// Detect language of a list of texts +/// # Arguments +/// * `texts` - A list of texts +/// * `n_jobs` - Number of cores to use, if <= 0 use all cores +/// # Example +/// ```python +/// >>> from whatlang import batch_detect +/// >>> texts = ["Hello world", "Bonjour le monde"] +/// >>> batch_detect(texts, 1) +/// [Language: eng - Script: Latn - Confidence: 0.999 - Is reliable: true, Language: fra - Script: Latn - Confidence: 0.999 - Is reliable: true] +/// ``` fn batch_detect(texts: Vec<&str>, n_jobs: i16) -> Vec { // Get number of cores let mut n_cores: usize = num_cpus::get(); @@ -165,23 +177,11 @@ fn batch_detect(texts: Vec<&str>, n_jobs: i16) -> Vec { .collect(); }); results - - // if parallel { - // texts - // .into_par_iter() - // .map(|text| convert_to_py_info(detect(text).unwrap())) - // .collect() - // } else { - // texts - // .into_iter() - // .map(|text| convert_to_py_info(detect(text).unwrap())) - // .collect() - // } } #[pyfunction] #[pyo3(name = "detect")] -#[pyo3(text_signature = "(text: str) -> Info")] +#[pyo3(text_signature = "(text: str)")] fn py_detect(text: &str) -> PyResult { let info = detect(text).unwrap(); Ok(convert_to_py_info(info)) @@ -189,7 +189,7 @@ fn py_detect(text: &str) -> PyResult { #[pyfunction] #[pyo3(name = "detect_script")] -#[pyo3(text_signature = "(text: str) -> Info")] +#[pyo3(text_signature = "(text: str)")] fn py_detect_script(text: &str) -> PyResult { let script = detect_script(text).unwrap(); Ok(convert_to_py_script(script)) @@ -197,7 +197,7 @@ fn py_detect_script(text: &str) -> PyResult { #[pyfunction] #[pyo3(name = "detect_lang")] -#[pyo3(text_signature = "(text: str) -> Lang")] +#[pyo3(text_signature = "(text: str)")] fn py_detect_lang(text: &str) -> PyResult { let lang = detect_lang(text).unwrap(); Ok(convert_to_py_lang(lang)) diff --git a/test.py b/test.py index 3e6adcd..aa17ae0 100644 --- a/test.py +++ b/test.py @@ -1,33 +1,51 @@ +""" +Test whatlang-pyo3 with unittest +""" + from whatlang import detect, detect_script, detect_lang, batch_detect import time +import unittest + +class TestWhatlang(unittest.TestCase): + + def test_detect(self): + result = detect("Ceci est écrit en français") + self.assertEqual(result.lang, "fra") + self.assertGreater(result.confidence, 0.1) + + def test_detect_script(self): + result = detect_script("Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!") + self.assertEqual(result.name, "Latin") -def main(): + def test_detect_lang(self): + result = detect_lang("Ceci est écrit en français") + self.assertEqual(result.lang, "fra") - result = detect("Ceci est écrit en français") - script = detect_script("Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!") - lang = detect_lang("Ceci est écrit en français") - batch = batch_detect(["Ceci est écrit en français", "Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!"]) - print(result) - print(script) - print(lang) + def test_batch_detect(self): + result = batch_detect(["Ceci est écrit en français", "Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!"]) + self.assertEqual(result[0].lang, "fra") + self.assertGreater(result[0].confidence, 0.1) + self.assertEqual(result[1].lang, "epo") + self.assertGreater(result[1].confidence, 0.5) -def compare_batch_with_single_performance(): - # Create a list of 10000 texts - from whatlang import detect, batch_detect - import time - n = 50_000 - texts = ["Ceci est écrit en français"] * n - print("--------------------------Batch detect--------------------------") - start = time.perf_counter() - batch_detect(texts, n_jobs=-1) - end = time.perf_counter() - print(f"Batch detect for {n} texts took {end - start} seconds") - print("--------------------------Single detect--------------------------") - start = time.perf_counter() - for text in texts: - detect(text) - end = time.perf_counter() - print(f"Single detect for {n} texts took {end - start} seconds") + def test_performance(self): + # Create a list of 10000 texts + n = 10_000 + texts = ["Ceci est écrit en français"] * n + print("\n--------------------------Batch detect--------------------------") + start = time.perf_counter() + batch_detect(texts, n_jobs=-1) + end = time.perf_counter() + batch = end - start + print(f"Batch detect for {n} texts took {batch} seconds") + print("--------------------------Single detect--------------------------\n") + start = time.perf_counter() + for text in texts: + detect(text) + end = time.perf_counter() + single = end - start + print(f"Single detect for {n} texts took {single} seconds") + self.assertGreaterEqual(single, batch) if __name__ == "__main__": - compare_batch_with_single_performance() + unittest.main()