Skip to content

Commit

Permalink
Include binarized model in the wheels (#7)
Browse files Browse the repository at this point in the history
* Move languagemodel to a separate workspace package

* Binarize models during package build

* Put binarized models into the wheel

* Need --force to binarize model if already exists

* Do not publish models in Github releases

Already included in the wheels

* Disable download feature by default

* Make sure empty dir heliport.data exists so models are created and
included

* [GA] Re-enable sccache and disable verbosity
  • Loading branch information
ZJaume authored Oct 31, 2024
1 parent a00341a commit 12c77ee
Show file tree
Hide file tree
Showing 16 changed files with 240 additions and 130 deletions.
17 changes: 0 additions & 17 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,11 @@ jobs:
args: --release --out dist --find-interpreter
sccache: 'true'
manylinux: auto

- run: pip install dist/*-cp39-*.whl
- run: heliport binarize

- name: Upload wheels
uses: actions/upload-artifact@v4
with:
name: wheels-linux-${{ matrix.platform.target }}
path: dist
- name: Upload models
uses: actions/upload-artifact@v4
with:
name: models-linux-${{ matrix.platform.target }}
path: |
/opt/hostedtoolcache/Python/${{ steps.setup-python.outputs.python-version }}/x64/lib/python3.9/site-packages/heliport/*.bin
/opt/hostedtoolcache/Python/${{ steps.setup-python.outputs.python-version }}/x64/lib/python3.9/site-packages/heliport/confidenceThresholds

sdist:
runs-on: ubuntu-22.04
Expand Down Expand Up @@ -97,12 +86,6 @@ jobs:
- uses: actions/download-artifact@v4
- name: Display downloaded files
run: ls -R
- name: Archive model files
run: for i in models-*; do tar czvf $i.tgz $i/*; done
- name: Upload Github release assets
uses: softprops/action-gh-release@v2
with:
files: models-*.tgz
- name: Publish to PyPI
uses: PyO3/maturin-action@v1
with:
Expand Down
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Ignore binary model files
*.ser
*.bin
# but keep heliport.data dir so maturin picks it up when doing a clean build

# Wheels
wheels*
Expand Down
15 changes: 11 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,24 @@ name = "heliport"
# use cdylib to enable maturin linkage
crate-type = ["lib", "cdylib"]

[workspace]
members = ["heliport-model"]

[build-dependencies]
heliport-model = { path = "heliport-model" }
anyhow = "1.0"
log = { version = "0.4" }
strum = { version = "0.25", features = ["derive"] }

[dependencies]
bitcode = "0.6"
heliport-model = { path = "heliport-model" }
regex = "1.10"
unicode-blocks = "0.1.8"
shingles = "0.1"
ordered-float = "4.2"
log = { version = "0.4" }
env_logger = "0.10"
strum = { version = "0.25", features = ["derive"] }
strum_macros = "0.25"
wyhash2 = "0.2.1"
pyo3 = { version = "0.22", features = ["gil-refs", "anyhow"], optional = true }
target = { version = "2.1.0", optional = true }
tempfile = { version = "3", optional = true }
Expand All @@ -40,6 +47,6 @@ test-log = "0.2.15"
[features]
# Put log features in default, to allow crates using heli as a library, disable them
default = ["cli", "log/max_level_debug", "log/release_max_level_debug"]
cli = ["download", "python", "dep:clap", "dep:target"]
cli = ["python", "dep:clap", "dep:target"]
download = ["dep:tokio", "dep:tempfile", "dep:reqwest", "dep:futures-util"]
python = ["dep:pyo3"]
50 changes: 50 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use std::env;
use std::fs;
use std::path::PathBuf;

use anyhow::Result;
use log::{info};
use strum::IntoEnumIterator;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};

fn main() -> Result<(), std::io::Error> {
    // Directory with the plain-text ngram frequency files shipped in the repo.
    let mut model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    model_path.push("LanguageModels");

    // maturin installs everything under heliport.data/platlib into the wheel's
    // platlib directory, so models written here ship inside the package.
    let platlib_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/heliport.data/platlib/heliport",
    );
    fs::create_dir_all(platlib_path)?;
    let save_path = PathBuf::from(platlib_path);

    // Re-run build script if language models or this script have changed
    println!(
        "cargo:rerun-if-changed={}",
        model_path.display()
    );
    println!("cargo:rerun-if-changed=build.rs");

    // NOTE: build scripts never install a logger, so `log::info!` output is
    // silently dropped here; use stderr instead (cargo shows it with `-vv`).
    //TODO parallelize
    for model_type in OrderNgram::iter() {
        let type_repr = model_type.to_string();
        eprintln!("Loading {type_repr} model");
        let model = ModelNgram::from_text(&model_path, model_type, None)
            .expect("failed to parse ngram frequency files");
        eprintln!("Created {} entries", model.dic.len());
        let filename = save_path.join(format!("{type_repr}.bin"));
        eprintln!("Saving {type_repr} model");
        model.save(&filename)
            .expect("failed to write binarized model");
    }
    eprintln!("Copying confidence thresholds file");
    fs::copy(
        model_path.join(Model::CONFIDENCE_FILE),
        save_path.join(Model::CONFIDENCE_FILE),
    ).expect("failed to copy confidence thresholds file");

    eprintln!("Finished");

    Ok(())
}
12 changes: 12 additions & 0 deletions heliport-model/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "heliport-model"
version = "0.1.0"
edition = "2021"

[dependencies]
bitcode = "0.6"
log = { version = "0.4" }
strum = { version = "0.25", features = ["derive"] }
strum_macros = "0.25"
wyhash2 = "0.2.1"
anyhow = "1.0"
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions heliport-model/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
//! Shared data structures for heliport: the ngram language-model types and the
//! language enum, usable by both the build script and the main crate.
pub mod languagemodel;
pub mod lang;
2 changes: 2 additions & 0 deletions heliport.data/platlib/heliport/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,5 @@ heliport-download = "heliport:cli_download"

[tool.maturin]
features = ["pyo3/extension-module"]
data = "heliport.data"
exclude = ["heliport.data/.gitkeep"]
66 changes: 66 additions & 0 deletions src/cli/binarize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::path::{Path, PathBuf};
use std::fs;
use std::process::exit;

use clap::Args;
use log::{error, warn, info};
use pyo3::prelude::*;
use strum::IntoEnumIterator;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
use crate::utils::Abort;
use crate::python::module_path;

/// Arguments for the `binarize` subcommand: converts plain-text ngram
/// frequency files into the binary model format used at runtime.
#[derive(Args, Clone)]
pub struct BinarizeCmd {
    // Defaults to "./LanguageModels" when omitted (see cli()).
    #[arg(help="Input directory where ngram frequency files are located")]
    input_dir: Option<PathBuf>,
    // Defaults to the installed Python module path when omitted (see cli()).
    #[arg(help="Output directory to place the binary files")]
    output_dir: Option<PathBuf>,
    #[arg(short, long, help="Force overwrite of output files if they already exist")]
    force: bool,
}

impl BinarizeCmd {
    /// Entry point of the `binarize` subcommand.
    ///
    /// Loads every ngram order from the text models in `input_dir` (default
    /// "./LanguageModels"), serializes each to a `.bin` file in `output_dir`
    /// (default: the installed Python module path) and copies the confidence
    /// thresholds file alongside them. Aborts the process on any failure.
    pub fn cli(self) -> PyResult<()> {
        // unwrap_or_else avoids building the default path / resolving the
        // module path when an explicit directory was provided.
        let model_path = self.input_dir
            .unwrap_or_else(|| PathBuf::from("./LanguageModels"));
        let save_path = self.output_dir
            .unwrap_or_else(|| module_path().unwrap());

        // Fail and warn the user if there is already a model and --force
        // was not given; the word-order model is used as the sentinel.
        let existing_model = save_path.join(format!("{}.bin", OrderNgram::Word));
        if !self.force && existing_model.exists() {
            warn!("Binarized models are now included in the PyPi package, \
                there is no need to binarize the model unless you are training a new one"
            );
            error!("Output model already exists, use '-f' to force overwrite");
            exit(1);
        }

        for model_type in OrderNgram::iter() {
            let type_repr = model_type.to_string();
            info!("Loading {type_repr} model");
            let model = ModelNgram::from_text(&model_path, model_type, None)
                .or_abort(1);
            let size = model.dic.len();
            info!("Created {size} entries");
            let filename = save_path.join(format!("{type_repr}.bin"));
            info!("Saving {type_repr} model");
            // &PathBuf deref-coerces to &Path; no Path::new needed.
            model.save(&filename).or_abort(1);
        }
        info!("Copying confidence thresholds file");
        fs::copy(
            model_path.join(Model::CONFIDENCE_FILE),
            save_path.join(Model::CONFIDENCE_FILE),
        ).or_abort(1);

        info!("Saved models at '{}'", save_path.display());
        info!("Finished");

        Ok(())
    }
}


34 changes: 34 additions & 0 deletions src/cli/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::path::PathBuf;
use std::env;

use clap::Args;
use pyo3::prelude::*;
use log::info;
use target;

use crate::python::module_path;
use crate::download;

/// Arguments for the `download` subcommand: fetches pre-built binary models
/// from a GitHub release matching this crate's version.
#[derive(Args, Clone)]
pub struct DownloadCmd {
    #[arg(help="Path to download the model, defaults to the module path")]
    path: Option<PathBuf>,
}

impl DownloadCmd {
    /// Entry point of the `download` subcommand.
    ///
    /// Builds the release-asset URL for the current OS/arch and this crate's
    /// version, then downloads and extracts the model tarball into `path`
    /// (default: the installed Python module path).
    pub fn cli(self) -> PyResult<()> {
        // unwrap_or_else: only resolve (and potentially fail in) module_path()
        // when no explicit path was given; unwrap_or would always evaluate it.
        let download_path = self.path.unwrap_or_else(|| {
            module_path().expect("could not locate the heliport module path")
        });

        let url = format!(
            "https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
            env!("CARGO_PKG_NAME"),
            env!("CARGO_PKG_VERSION"),
            target::os(),
            target::arch());

        // to_str() is None only for non-UTF-8 paths — TODO confirm the helper
        // could take a &Path instead of &str to avoid this panic path.
        download::download_file_and_extract(&url, download_path.to_str().unwrap())
            .expect("model download or extraction failed");
        info!("Finished");

        Ok(())
    }
}
109 changes: 6 additions & 103 deletions src/cli.rs → src/cli/identify.rs
Original file line number Diff line number Diff line change
@@ -1,105 +1,21 @@
use std::io::{self, BufRead, BufReader, Write, BufWriter};
use std::fs::{copy, File};
use std::fs::File;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::env;

use anyhow::{Context, Result};
use clap::{Parser, Subcommand, Args};
use clap::Args;
use itertools::Itertools;
use log::{debug};
use pyo3::prelude::*;
use log::{info, debug};
use env_logger::Env;
use strum::IntoEnumIterator;
use target;

use crate::languagemodel::{Model, ModelNgram, OrderNgram};
use crate::lang::Lang;
use heliport_model::lang::Lang;
use crate::identifier::Identifier;
use crate::utils::Abort;
use crate::python::module_path;
use crate::download;

#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
pub struct Cli {
#[command(subcommand)]
command: Commands,
}

#[derive(Subcommand, Clone)]
enum Commands {
#[command(about="Download heliport model from GitHub")]
Download(DownloadCmd),
#[command(about="Binarize heliport model")]
Binarize(BinarizeCmd),
#[command(about="Identify languages of input text", visible_alias="detect")]
Identify(IdentifyCmd),
}

#[derive(Args, Clone)]
struct BinarizeCmd {
#[arg(help="Input directory where ngram frequency files are located")]
input_dir: Option<PathBuf>,
#[arg(help="Output directory to place the binary files")]
output_dir: Option<PathBuf>,
}

impl BinarizeCmd {
fn cli(self) -> PyResult<()> {
let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels"));
let save_path = self.output_dir.unwrap_or(module_path().unwrap());

for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = ModelNgram::from_text(&model_path, model_type, None)
.or_abort(1);
let size = model.dic.len();
info!("Created {size} entries");
let filename = save_path.join(format!("{type_repr}.bin"));
info!("Saving {type_repr} model");
model.save(Path::new(&filename)).or_abort(1);
}
info!("Copying confidence thresholds file");
copy(
model_path.join(Model::CONFIDENCE_FILE),
save_path.join(Model::CONFIDENCE_FILE),
).or_abort(1);

info!("Saved models at '{}'", save_path.display());
info!("Finished");

Ok(())
}
}

#[derive(Args, Clone)]
struct DownloadCmd {
#[arg(help="Path to download the model, defaults to the module path")]
path: Option<PathBuf>,
}

impl DownloadCmd {
fn cli(self) -> PyResult<()> {
let download_path = self.path.unwrap_or(module_path().unwrap());

let url = format!(
"https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_VERSION"),
target::os(),
target::arch());

download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
info!("Finished");

Ok(())
}
}

#[derive(Args, Clone, Debug)]
struct IdentifyCmd {
pub struct IdentifyCmd {
#[arg(help="Number of parallel threads to use.\n0 means no multi-threading\n1 means running the identification in a separated thread\n>1 run multithreading",
short='j',
long,
Expand Down Expand Up @@ -154,7 +70,7 @@ fn parse_langs(langs_text: &Vec<String>) -> Result<Vec<Lang>> {
}

impl IdentifyCmd {
fn cli(self) -> PyResult<()> {
pub fn cli(self) -> PyResult<()> {
// If provided, parse the list of relevant languages
let mut relevant_langs = None;
if let Some(r) = &self.relevant_langs {
Expand Down Expand Up @@ -264,17 +180,4 @@ impl IdentifyCmd {
}
}

#[pyfunction]
pub fn cli_run() -> PyResult<()> {
// parse the cli arguments, skip the first one that is the path to the Python entry point
let os_args = std::env::args_os().skip(1);
let args = Cli::parse_from(os_args);
debug!("Module path found at: {}", module_path().expect("Could not found module path").display());
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();

match args.command {
Commands::Download(cmd) => { cmd.cli() },
Commands::Binarize(cmd) => { cmd.cli() },
Commands::Identify(cmd) => { cmd.cli() },
}
}
Loading

0 comments on commit 12c77ee

Please sign in to comment.