Skip to content

Commit

Permalink
Parallelize model binarization
Browse files Browse the repository at this point in the history
  • Loading branch information
ZJaume committed Nov 18, 2024
1 parent f1a7510 commit 3b76421
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
1 change: 1 addition & 0 deletions heliport-model/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ strum = { version = "0.25", features = ["derive"] }
strum_macros = "0.25"
wyhash2 = "0.2.1"
anyhow = "1.0"
rayon = "1.10"
32 changes: 22 additions & 10 deletions heliport-model/src/languagemodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::thread;
use anyhow::{Context, Result, bail};
use bitcode;
use log::{info, debug, warn};
use rayon::prelude::*;
use strum::{Display, EnumCount, IntoEnumIterator};
use strum_macros::EnumIter;

Expand Down Expand Up @@ -87,6 +88,7 @@ impl ModelNgram {
dic: HashMap::default(),
model_type: model_type.clone(),
};
let model_repr = model_type.to_string();

// Open languagelist for this model
let lang_list = fs::read_to_string(model_dir.join("languagelist"))
Expand All @@ -99,7 +101,7 @@ impl ModelNgram {
let lang_repr = lang.to_string().to_lowercase();
// Models may not have all the language codes supported by the library
if !lang_list.contains(&lang_repr[..]) {
warn!("Language '{lang_repr}' not found in languagelist, omitting");
warn!("{model_repr}: Language '{lang_repr}' not found in languagelist, omitting");
continue;
}

Expand Down Expand Up @@ -291,16 +293,26 @@ impl Index<usize> for Model {

/// Binarize the text models found in `model_path` and save them under `save_path`
pub fn binarize(save_path: &Path, model_path: &Path) -> Result<()> {
for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = ModelNgram::from_text(&model_path, model_type, None)?;
let size = model.dic.len();
info!("Created {size} entries");
let filename = save_path.join(format!("{type_repr}.bin"));
info!("Saving {type_repr} model");
model.save(Path::new(&filename))?;
let orders: Vec<_ > = OrderNgram::iter().collect();

let results: Vec<Result<_>> = orders
.par_iter()
.panic_fuse()
.map(|model_type| -> Result<()> {
let type_repr = model_type.to_string();
info!("{type_repr}: loading text model");
let model = ModelNgram::from_text(&model_path, model_type.clone(), None)?;
let size = model.dic.len();
let filename = save_path.join(format!("{type_repr}.bin"));
info!("{type_repr}: saving binarized model with {size} entries");
model.save(Path::new(&filename))
}).collect();

// Propagate the first error encountered, if any
for r in results {
let _ = r?;
}

info!("Copying confidence thresholds file");
fs::copy(
model_path.join(Model::CONFIDENCE_FILE),
Expand Down

0 comments on commit 3b76421

Please sign in to comment.