Skip to content

Commit

Permalink
Move binarization to a function in languagemodel.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZJaume committed Oct 31, 2024
1 parent c7b5ddc commit beef29f
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 51 deletions.
28 changes: 3 additions & 25 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::fs;
use std::path::PathBuf;

use anyhow::Result;
use log::{info};
use strum::IntoEnumIterator;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
use heliport_model::languagemodel::binarize;

fn main() -> Result<(), std::io::Error> {
fn main() -> Result<()> {
let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
model_path.push("LanguageModels");

Expand All @@ -26,25 +24,5 @@ fn main() -> Result<(), std::io::Error> {
);
println!("cargo:rerun-if-changed=build.rs");

//TODO parallelize
for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = ModelNgram::from_text(&model_path, model_type, None)
.unwrap();
let size = model.dic.len();
info!("Created {size} entries");
let filename = save_path.join(format!("{type_repr}.bin"));
info!("Saving {type_repr} model");
model.save(&filename).unwrap();
}
info!("Copying confidence thresholds file");
fs::copy(
model_path.join(Model::CONFIDENCE_FILE),
save_path.join(Model::CONFIDENCE_FILE),
).unwrap();

info!("Finished");

Ok(())
binarize(&save_path, &model_path)
}
25 changes: 24 additions & 1 deletion heliport-model/src/languagemodel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use std::thread;

use anyhow::{Context, Result, bail};
use bitcode;
use log::{debug, warn};
use log::{info, debug, warn};
use strum::{Display, EnumCount, IntoEnumIterator};
use strum_macros::EnumIter;

Expand Down Expand Up @@ -289,6 +289,29 @@ impl Index<usize> for Model {
}
}

/// Binarize all n-gram models and save them in a directory.
///
/// For every [`OrderNgram`] variant, a [`ModelNgram`] is built from the text
/// files under `model_path` and written to `save_path` as a `.bin` file named
/// after the variant's string representation. Afterwards the file named by
/// [`Model::CONFIDENCE_FILE`] is copied from `model_path` into `save_path`.
///
/// # Errors
///
/// Returns the first error raised while loading a model, saving it, or
/// copying the confidence thresholds file. Processing stops at that point,
/// so the remaining models are not written.
pub fn binarize(save_path: &Path, model_path: &Path) -> Result<()> {
    for model_type in OrderNgram::iter() {
        let type_repr = model_type.to_string();
        info!("Loading {type_repr} model");
        // `model_path` is already a `&Path`, so it is passed through as-is.
        let model = ModelNgram::from_text(model_path, model_type, None)?;
        let size = model.dic.len();
        info!("Created {size} entries");
        let filename = save_path.join(format!("{type_repr}.bin"));
        info!("Saving {type_repr} model");
        // A borrowed `PathBuf` coerces to `&Path`; no `Path::new` wrapper needed.
        model.save(&filename)?;
    }

    info!("Copying confidence thresholds file");
    fs::copy(
        model_path.join(Model::CONFIDENCE_FILE),
        save_path.join(Model::CONFIDENCE_FILE),
    )?;

    info!("Saved models at '{}'", save_path.display());
    info!("Finished");
    Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
29 changes: 4 additions & 25 deletions src/cli/binarize.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use std::path::{Path, PathBuf};
use std::fs;
use std::path::{PathBuf};
use std::process::exit;

use clap::Args;
use log::{error, warn, info};
use log::{error, warn};
use pyo3::prelude::*;
use strum::IntoEnumIterator;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
use heliport_model::languagemodel::{binarize, OrderNgram};
use crate::utils::Abort;
use crate::python::module_path;

Expand Down Expand Up @@ -39,26 +37,7 @@ impl BinarizeCmd {
exit(1);
}

for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = ModelNgram::from_text(&model_path, model_type, None)
.or_abort(1);
let size = model.dic.len();
info!("Created {size} entries");
let filename = save_path.join(format!("{type_repr}.bin"));
info!("Saving {type_repr} model");
model.save(Path::new(&filename)).or_abort(1);
}
info!("Copying confidence thresholds file");
fs::copy(
model_path.join(Model::CONFIDENCE_FILE),
save_path.join(Model::CONFIDENCE_FILE),
).or_abort(1);

info!("Saved models at '{}'", save_path.display());
info!("Finished");

binarize(&save_path, &model_path).or_abort(1);
Ok(())
}
}
Expand Down

0 comments on commit beef29f

Please sign in to comment.