Skip to content

Commit

Permalink
Add unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
omarmhaimdat committed Jan 29, 2023
1 parent 5212b17 commit d8e27fa
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 41 deletions.
30 changes: 15 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use pyo3::prelude::*;
use rayon::prelude::*;
use std::time::Instant;
use whatlang::{detect, detect_lang, detect_script};

use crate::utils::{colorize, get_progress_bar, lang_to_iso639_1, TermColor};
Expand Down Expand Up @@ -137,6 +138,17 @@ fn convert_to_py_lang(lang: whatlang::Lang) -> PyLang {
}
}

/// Detect language of a list of texts
/// # Arguments
/// * `texts` - A list of texts
/// * `n_jobs` - Number of cores to use, if <= 0 use all cores
/// # Example
/// ```python
/// >>> from whatlang import batch_detect
/// >>> texts = ["Hello world", "Bonjour le monde"]
/// >>> batch_detect(texts, 1)
/// [Language: eng - Script: Latn - Confidence: 0.999 - Is reliable: true, Language: fra - Script: Latn - Confidence: 0.999 - Is reliable: true]
/// ```
fn batch_detect(texts: Vec<&str>, n_jobs: i16) -> Vec<PyInfo> {
// Get number of cores
let mut n_cores: usize = num_cpus::get();
Expand Down Expand Up @@ -165,39 +177,27 @@ fn batch_detect(texts: Vec<&str>, n_jobs: i16) -> Vec<PyInfo> {
.collect();
});
results

// if parallel {
// texts
// .into_par_iter()
// .map(|text| convert_to_py_info(detect(text).unwrap()))
// .collect()
// } else {
// texts
// .into_iter()
// .map(|text| convert_to_py_info(detect(text).unwrap()))
// .collect()
// }
}

#[pyfunction]
#[pyo3(name = "detect")]
#[pyo3(text_signature = "(text: str) -> Info")]
#[pyo3(text_signature = "(text: str)")]
fn py_detect(text: &str) -> PyResult<PyInfo> {
let info = detect(text).unwrap();
Ok(convert_to_py_info(info))
}

#[pyfunction]
#[pyo3(name = "detect_script")]
#[pyo3(text_signature = "(text: str) -> Info")]
#[pyo3(text_signature = "(text: str)")]
fn py_detect_script(text: &str) -> PyResult<PyScript> {
let script = detect_script(text).unwrap();
Ok(convert_to_py_script(script))
}

#[pyfunction]
#[pyo3(name = "detect_lang")]
#[pyo3(text_signature = "(text: str) -> Lang")]
#[pyo3(text_signature = "(text: str)")]
fn py_detect_lang(text: &str) -> PyResult<PyLang> {
let lang = detect_lang(text).unwrap();
Ok(convert_to_py_lang(lang))
Expand Down
70 changes: 44 additions & 26 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,51 @@
"""
Test whatlang-pyo3 with unittest
"""

from whatlang import detect, detect_script, detect_lang, batch_detect
import time
import unittest

class TestWhatlang(unittest.TestCase):

def test_detect(self):
    """`detect` on a French sentence reports "fra" with confidence above 0.1."""
    french_sentence = "Ceci est écrit en français"
    info = detect(french_sentence)
    self.assertEqual(info.lang, "fra")
    self.assertGreater(info.confidence, 0.1)

def test_detect_script(self):
    """`detect_script` on an Esperanto sentence reports a script named "Latin"."""
    esperanto_sentence = "Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!"
    script = detect_script(esperanto_sentence)
    self.assertEqual(script.name, "Latin")

def main():
def test_detect_lang(self):
    """`detect_lang` on a French sentence reports the language code "fra"."""
    french_sentence = "Ceci est écrit en français"
    language = detect_lang(french_sentence)
    self.assertEqual(language.lang, "fra")

result = detect("Ceci est écrit en français")
script = detect_script("Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!")
lang = detect_lang("Ceci est écrit en français")
batch = batch_detect(["Ceci est écrit en français", "Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!"])
print(result)
print(script)
print(lang)
def test_batch_detect(self):
    """`batch_detect` returns one result per input text, in input order."""
    samples = [
        "Ceci est écrit en français",
        "Ĉu vi ne volas eklerni Esperanton? Bonvolu! Estas unu de la plej bonaj aferoj!",
    ]
    # (expected language code, confidence floor) for each sample, by position.
    expectations = [("fra", 0.1), ("epo", 0.5)]

    infos = batch_detect(samples)
    for position, (code, floor) in enumerate(expectations):
        self.assertEqual(infos[position].lang, code)
        self.assertGreater(infos[position].confidence, floor)

# Manual benchmark: times one `batch_detect` call over n identical texts, then a
# loop of n single `detect` calls, and prints both durations. Nothing is
# asserted or returned.
def compare_batch_with_single_performance():
# Create a list of 50_000 identical texts
from whatlang import detect, batch_detect
import time
n = 50_000
texts = ["Ceci est écrit en français"] * n
print("--------------------------Batch detect--------------------------")
# Time the whole batch as a single call.
start = time.perf_counter()
batch_detect(texts, n_jobs=-1)
end = time.perf_counter()
print(f"Batch detect for {n} texts took {end - start} seconds")
print("--------------------------Single detect--------------------------")
# Time the same texts again, one `detect` call per text.
start = time.perf_counter()
for text in texts:
detect(text)
end = time.perf_counter()
print(f"Single detect for {n} texts took {end - start} seconds")
def test_performance(self):
    """Time one `batch_detect` call against a loop of single `detect` calls.

    Prints both durations and asserts that the loop of single calls took at
    least as long as the batch call.
    """
    # Build 10_000 copies of the same French sentence.
    n = 10_000
    texts = ["Ceci est écrit en français"] * n

    print("\n--------------------------Batch detect--------------------------")
    batch_started = time.perf_counter()
    batch_detect(texts, n_jobs=-1)
    batch_elapsed = time.perf_counter() - batch_started
    print(f"Batch detect for {n} texts took {batch_elapsed} seconds")

    print("--------------------------Single detect--------------------------\n")
    single_started = time.perf_counter()
    for sample in texts:
        detect(sample)
    single_elapsed = time.perf_counter() - single_started
    print(f"Single detect for {n} texts took {single_elapsed} seconds")

    # NOTE(review): this compares wall-clock timings, so the outcome depends on
    # the machine and its load at the time of the run.
    self.assertGreaterEqual(single_elapsed, batch_elapsed)

if __name__ == "__main__":
compare_batch_with_single_performance()
unittest.main()

0 comments on commit d8e27fa

Please sign in to comment.