-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Include binarized model in the wheels (#7)
* Move languagemodel to a separated workspace package * Binarize models during package build * Put binarized models into the wheel * Need --force to binarize model if already exists * Do not publish models in Github releases Already included in the wheels * Disable download feature by default * Make sure empty dir heliport.data exists so models are created and included * [GA] Re-enable sccache and disable verbosity
- Loading branch information
Showing
16 changed files
with
240 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
use std::env; | ||
use std::fs; | ||
use std::path::PathBuf; | ||
|
||
use anyhow::Result; | ||
use log::{info}; | ||
use strum::IntoEnumIterator; | ||
|
||
use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram}; | ||
|
||
fn main() -> Result<(), std::io::Error> { | ||
let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); | ||
model_path.push("LanguageModels"); | ||
|
||
let platlib_path = concat!( | ||
env!("CARGO_MANIFEST_DIR"), | ||
"/heliport.data/platlib/heliport", | ||
); | ||
fs::create_dir_all(&platlib_path)?; | ||
let save_path = PathBuf::from(&platlib_path); | ||
|
||
// Re-run build script if language models has changed | ||
println!( | ||
"cargo:rerun-if-changed={}", | ||
model_path.display() | ||
); | ||
println!("cargo:rerun-if-changed=build.rs"); | ||
|
||
//TODO parallelize | ||
for model_type in OrderNgram::iter() { | ||
let type_repr = model_type.to_string(); | ||
info!("Loading {type_repr} model"); | ||
let model = ModelNgram::from_text(&model_path, model_type, None) | ||
.unwrap(); | ||
let size = model.dic.len(); | ||
info!("Created {size} entries"); | ||
let filename = save_path.join(format!("{type_repr}.bin")); | ||
info!("Saving {type_repr} model"); | ||
model.save(&filename).unwrap(); | ||
} | ||
info!("Copying confidence thresholds file"); | ||
fs::copy( | ||
model_path.join(Model::CONFIDENCE_FILE), | ||
save_path.join(Model::CONFIDENCE_FILE), | ||
).unwrap(); | ||
|
||
info!("Finished"); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
[package] | ||
name = "heliport-model" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
[dependencies] | ||
bitcode = "0.6" | ||
log = { version = "0.4" } | ||
strum = { version = "0.25", features = ["derive"] } | ||
strum_macros = "0.25" | ||
wyhash2 = "0.2.1" | ||
anyhow = "1.0" |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pub mod languagemodel; | ||
pub mod lang; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* | ||
!.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use std::path::{Path, PathBuf}; | ||
use std::fs; | ||
use std::process::exit; | ||
|
||
use clap::Args; | ||
use log::{error, warn, info}; | ||
use pyo3::prelude::*; | ||
use strum::IntoEnumIterator; | ||
|
||
use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram}; | ||
use crate::utils::Abort; | ||
use crate::python::module_path; | ||
|
||
#[derive(Args, Clone)] | ||
pub struct BinarizeCmd { | ||
#[arg(help="Input directory where ngram frequency files are located")] | ||
input_dir: Option<PathBuf>, | ||
#[arg(help="Output directory to place the binary files")] | ||
output_dir: Option<PathBuf>, | ||
#[arg(short, long, help="Force overwrite of output files if they already exist")] | ||
force: bool, | ||
} | ||
|
||
impl BinarizeCmd { | ||
pub fn cli(self) -> PyResult<()> { | ||
let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels")); | ||
let save_path = self.output_dir.unwrap_or(module_path().unwrap()); | ||
|
||
// Fail and warn the use if there is already a model | ||
if !self.force && | ||
save_path.join( | ||
format!("{}.bin", OrderNgram::Word.to_string()) | ||
).exists() | ||
{ | ||
warn!("Binarized models are now included in the PyPi package, \ | ||
there is no need to binarize the model unless you are training a new one" | ||
); | ||
error!("Output model already exists, use '-f' to force overwrite"); | ||
exit(1); | ||
} | ||
|
||
for model_type in OrderNgram::iter() { | ||
let type_repr = model_type.to_string(); | ||
info!("Loading {type_repr} model"); | ||
let model = ModelNgram::from_text(&model_path, model_type, None) | ||
.or_abort(1); | ||
let size = model.dic.len(); | ||
info!("Created {size} entries"); | ||
let filename = save_path.join(format!("{type_repr}.bin")); | ||
info!("Saving {type_repr} model"); | ||
model.save(Path::new(&filename)).or_abort(1); | ||
} | ||
info!("Copying confidence thresholds file"); | ||
fs::copy( | ||
model_path.join(Model::CONFIDENCE_FILE), | ||
save_path.join(Model::CONFIDENCE_FILE), | ||
).or_abort(1); | ||
|
||
info!("Saved models at '{}'", save_path.display()); | ||
info!("Finished"); | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
use std::path::PathBuf; | ||
use std::env; | ||
|
||
use clap::Args; | ||
use pyo3::prelude::*; | ||
use log::info; | ||
use target; | ||
|
||
use crate::python::module_path; | ||
use crate::download; | ||
|
||
#[derive(Args, Clone)] | ||
pub struct DownloadCmd { | ||
#[arg(help="Path to download the model, defaults to the module path")] | ||
path: Option<PathBuf>, | ||
} | ||
|
||
impl DownloadCmd { | ||
pub fn cli(self) -> PyResult<()> { | ||
let download_path = self.path.unwrap_or(module_path().unwrap()); | ||
|
||
let url = format!( | ||
"https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz", | ||
env!("CARGO_PKG_NAME"), | ||
env!("CARGO_PKG_VERSION"), | ||
target::os(), | ||
target::arch()); | ||
|
||
download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap(); | ||
info!("Finished"); | ||
|
||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.