diff --git a/heliport-model/Cargo.toml b/heliport-model/Cargo.toml index f2d6f75..9844f22 100644 --- a/heliport-model/Cargo.toml +++ b/heliport-model/Cargo.toml @@ -10,3 +10,4 @@ strum = { version = "0.25", features = ["derive"] } strum_macros = "0.25" wyhash2 = "0.2.1" anyhow = "1.0" +rayon = "1.10" diff --git a/heliport-model/src/languagemodel.rs b/heliport-model/src/languagemodel.rs index efd2aa8..0bd22ed 100644 --- a/heliport-model/src/languagemodel.rs +++ b/heliport-model/src/languagemodel.rs @@ -10,6 +10,7 @@ use std::thread; use anyhow::{Context, Result, bail}; use bitcode; use log::{info, debug, warn}; +use rayon::prelude::*; use strum::{Display, EnumCount, IntoEnumIterator}; use strum_macros::EnumIter; @@ -87,6 +88,7 @@ impl ModelNgram { dic: HashMap::default(), model_type: model_type.clone(), }; + let model_repr = model_type.to_string(); // Open languagelist for this model let lang_list = fs::read_to_string(model_dir.join("languagelist")) @@ -99,7 +101,7 @@ impl ModelNgram { let lang_repr = lang.to_string().to_lowercase(); // Models may not have all the language codes supported by the library if !lang_list.contains(&lang_repr[..]) { - warn!("Language '{lang_repr}' not found in languagelist, omitting"); + warn!("{model_repr}: Language '{lang_repr}' not found in languagelist, omitting"); continue; } @@ -291,16 +293,26 @@ impl Index for Model { /// Binarize models and save in a path pub fn binarize(save_path: &Path, model_path: &Path) -> Result<()> { - for model_type in OrderNgram::iter() { - let type_repr = model_type.to_string(); - info!("Loading {type_repr} model"); - let model = ModelNgram::from_text(&model_path, model_type, None)?; - let size = model.dic.len(); - info!("Created {size} entries"); - let filename = save_path.join(format!("{type_repr}.bin")); - info!("Saving {type_repr} model"); - model.save(Path::new(&filename))?; + let orders: Vec<_ > = OrderNgram::iter().collect(); + + let results: Vec> = orders + .par_iter() + .panic_fuse() + .map(|model_type| -> Result<()> { + let type_repr = model_type.to_string(); + info!("{type_repr}: loading text model"); + let model = ModelNgram::from_text(&model_path, model_type.clone(), None)?; + let size = model.dic.len(); + let filename = save_path.join(format!("{type_repr}.bin")); + info!("{type_repr}: saving binarized model with {size} entries"); + model.save(Path::new(&filename)) + }).collect(); + + // If there is one error, propagate + for r in results { + let _ = r?; } + info!("Copying confidence thresholds file"); fs::copy( model_path.join(Model::CONFIDENCE_FILE),