diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index e07bd55..9103597 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -49,9 +49,10 @@ jobs:
               ;;
             esac
           target: ${{ matrix.platform.target }}
-          args: --release --out dist --find-interpreter
+          args: --release --out dist --find-interpreter -vv
           sccache: 'true'
           manylinux: auto
+      - run: ls -R
       - name: Upload wheels
         uses: actions/upload-artifact@v4
         with:
diff --git a/Cargo.toml b/Cargo.toml
index 0bc14be..a69c62b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -47,6 +47,6 @@ test-log = "0.2.15"
 [features]
 # Put log features in default, to allow crates using heli as a library, disable them
 default = ["cli", "log/max_level_debug", "log/release_max_level_debug"]
-cli = ["download", "python", "dep:clap", "dep:target"]
+cli = ["python", "dep:clap", "dep:target"]
 download = ["dep:tokio", "dep:tempfile", "dep:reqwest", "dep:futures-util"]
 python = ["dep:pyo3"]
diff --git a/src/cli/binarize.rs b/src/cli/binarize.rs
new file mode 100644
index 0000000..bca0bf4
--- /dev/null
+++ b/src/cli/binarize.rs
@@ -0,0 +1,66 @@
+use std::path::{Path, PathBuf};
+use std::fs;
+use std::process::exit;
+
+use clap::Args;
+use log::{error, warn, info};
+use pyo3::prelude::*;
+use strum::IntoEnumIterator;
+
+use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
+use crate::utils::Abort;
+use crate::python::module_path;
+
+#[derive(Args, Clone)]
+pub struct BinarizeCmd {
+    #[arg(help="Input directory where ngram frequency files are located")]
+    input_dir: Option<PathBuf>,
+    #[arg(help="Output directory to place the binary files")]
+    output_dir: Option<PathBuf>,
+    #[arg(short, long, help="Force overwrite of output files if they already exist")]
+    force: bool,
+}
+
+impl BinarizeCmd {
+    pub fn cli(self) -> PyResult<()> {
+        let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels"));
+        let save_path = self.output_dir.unwrap_or(module_path().unwrap());
+
+        // Fail and warn the user if there is already a model
+        if !self.force &&
+            save_path.join(
+                format!("{}.bin", OrderNgram::Word.to_string())
+            ).exists()
+        {
+            warn!("Binarized models are now included in the PyPi package, \
+                  there is no need to binarize the model unless you are training a new one"
+            );
+            error!("Output model already exists, use '-f' to force overwrite");
+            exit(1);
+        }
+
+        for model_type in OrderNgram::iter() {
+            let type_repr = model_type.to_string();
+            info!("Loading {type_repr} model");
+            let model = ModelNgram::from_text(&model_path, model_type, None)
+                .or_abort(1);
+            let size = model.dic.len();
+            info!("Created {size} entries");
+            let filename = save_path.join(format!("{type_repr}.bin"));
+            info!("Saving {type_repr} model");
+            model.save(Path::new(&filename)).or_abort(1);
+        }
+        info!("Copying confidence thresholds file");
+        fs::copy(
+            model_path.join(Model::CONFIDENCE_FILE),
+            save_path.join(Model::CONFIDENCE_FILE),
+        ).or_abort(1);
+
+        info!("Saved models at '{}'", save_path.display());
+        info!("Finished");
+
+        Ok(())
+    }
+}
+
+
diff --git a/src/cli/download.rs b/src/cli/download.rs
new file mode 100644
index 0000000..500e542
--- /dev/null
+++ b/src/cli/download.rs
@@ -0,0 +1,34 @@
+use std::path::PathBuf;
+use std::env;
+
+use clap::Args;
+use pyo3::prelude::*;
+use log::info;
+use target;
+
+use crate::python::module_path;
+use crate::download;
+
+#[derive(Args, Clone)]
+pub struct DownloadCmd {
+    #[arg(help="Path to download the model, defaults to the module path")]
+    path: Option<PathBuf>,
+}
+
+impl DownloadCmd {
+    pub fn cli(self) -> PyResult<()> {
+        let download_path = self.path.unwrap_or(module_path().unwrap());
+
+        let url = format!(
+            "https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
+            env!("CARGO_PKG_NAME"),
+            env!("CARGO_PKG_VERSION"),
+            target::os(),
+            target::arch());
+
+        download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
+        info!("Finished");
+
+        Ok(())
+    }
+}
diff --git a/src/cli.rs b/src/cli/identify.rs
similarity index 60%
rename from src/cli.rs
rename to src/cli/identify.rs
index f1419a3..0f2cd03 100644
--- a/src/cli.rs
+++ b/src/cli/identify.rs
@@ -1,121 +1,21 @@
 use std::io::{self, BufRead, BufReader, Write, BufWriter};
-use std::fs::{copy, File};
+use std::fs::File;
 use std::path::{Path, PathBuf};
 use std::str::FromStr;
-use std::env;
-use std::process::exit;
 
 use anyhow::{Context, Result};
-use clap::{Parser, Subcommand, Args};
+use clap::Args;
 use itertools::Itertools;
+use log::{debug};
 use pyo3::prelude::*;
-use log::{error, warn, info, debug};
-use env_logger::Env;
-use strum::IntoEnumIterator;
-use target;
 
-use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
 use heliport_model::lang::Lang;
 use crate::identifier::Identifier;
 use crate::utils::Abort;
 use crate::python::module_path;
-use crate::download;
-
-#[derive(Parser, Clone)]
-#[command(version, about, long_about = None)]
-pub struct Cli {
-    #[command(subcommand)]
-    command: Commands,
-}
-
-#[derive(Subcommand, Clone)]
-enum Commands {
-    #[command(about="Download heliport model from GitHub")]
-    Download(DownloadCmd),
-    #[command(about="Binarize heliport model")]
-    Binarize(BinarizeCmd),
-    #[command(about="Identify languages of input text", visible_alias="detect")]
-    Identify(IdentifyCmd),
-}
-
-#[derive(Args, Clone)]
-struct BinarizeCmd {
-    #[arg(help="Input directory where ngram frequency files are located")]
-    input_dir: Option<PathBuf>,
-    #[arg(help="Output directory to place the binary files")]
-    output_dir: Option<PathBuf>,
-    #[arg(short, long, help="Force overwrite of output files if they already exist")]
-    force: bool,
-}
-
-impl BinarizeCmd {
-    fn cli(self) -> PyResult<()> {
-        let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels"));
-        let save_path = self.output_dir.unwrap_or(module_path().unwrap());
-
-        // Fail and warn the use if there is already a model
-        if !self.force &&
-            save_path.join(
-                format!("{}.bin", OrderNgram::Word.to_string())
-            ).exists()
-        {
-            warn!("Binarized models are now included in the PyPi package, \
-                  there is no need to binarize the model unless you are training a new one"
-            );
-            error!("Output model already exists, use '-f' to force overwrite");
-            exit(1);
-        }
-
-        for model_type in OrderNgram::iter() {
-            let type_repr = model_type.to_string();
-            info!("Loading {type_repr} model");
-            let model = ModelNgram::from_text(&model_path, model_type, None)
-                .or_abort(1);
-            let size = model.dic.len();
-            info!("Created {size} entries");
-            let filename = save_path.join(format!("{type_repr}.bin"));
-            info!("Saving {type_repr} model");
-            model.save(Path::new(&filename)).or_abort(1);
-        }
-        info!("Copying confidence thresholds file");
-        copy(
-            model_path.join(Model::CONFIDENCE_FILE),
-            save_path.join(Model::CONFIDENCE_FILE),
-        ).or_abort(1);
-
-        info!("Saved models at '{}'", save_path.display());
-        info!("Finished");
-
-        Ok(())
-    }
-}
-
-#[derive(Args, Clone)]
-struct DownloadCmd {
-    #[arg(help="Path to download the model, defaults to the module path")]
-    path: Option<PathBuf>,
-}
-
-impl DownloadCmd {
-    fn cli(self) -> PyResult<()> {
-        let download_path = self.path.unwrap_or(module_path().unwrap());
-
-        let url = format!(
-            "https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
-            env!("CARGO_PKG_NAME"),
-            env!("CARGO_PKG_VERSION"),
-            target::os(),
-            target::arch());
-
-        download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
-        info!("Finished");
-
-        Ok(())
-    }
-}
 
 #[derive(Args, Clone, Debug)]
-struct IdentifyCmd {
+pub struct IdentifyCmd {
     #[arg(help="Number of parallel threads to use.\n0 means no multi-threading\n1 means running the identification in a separated thread\n>1 run multithreading",
         short='j',
         long,
@@ -170,7 +70,7 @@ fn parse_langs(langs_text: &Vec<String>) -> Result<Vec<Lang>> {
 }
 
 impl IdentifyCmd {
-    fn cli(self) -> PyResult<()> {
+    pub fn cli(self) -> PyResult<()> {
         // If provided, parse the list of relevant languages
         let mut relevant_langs = None;
         if let Some(r) = &self.relevant_langs {
@@ -280,17 +180,4 @@ impl IdentifyCmd {
     }
 }
 
-#[pyfunction]
-pub fn cli_run() -> PyResult<()> {
-    // parse the cli arguments, skip the first one that is the path to the Python entry point
-    let os_args = std::env::args_os().skip(1);
-    let args = Cli::parse_from(os_args);
-    debug!("Module path found at: {}", module_path().expect("Could not found module path").display());
-    env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
-    match args.command {
-        Commands::Download(cmd) => { cmd.cli() },
-        Commands::Binarize(cmd) => { cmd.cli() },
-        Commands::Identify(cmd) => { cmd.cli() },
-    }
-}
 
diff --git a/src/cli/mod.rs b/src/cli/mod.rs
new file mode 100644
index 0000000..c315c3c
--- /dev/null
+++ b/src/cli/mod.rs
@@ -0,0 +1,51 @@
+mod identify;
+#[cfg(feature = "download")]
+mod download;
+mod binarize;
+
+use clap::{Subcommand, Parser};
+use log::{debug};
+use pyo3::prelude::*;
+use env_logger::Env;
+
+use crate::python::module_path;
+#[cfg(feature = "download")]
+use self::download::DownloadCmd;
+use self::binarize::BinarizeCmd;
+use self::identify::IdentifyCmd;
+
+#[derive(Parser, Clone)]
+#[command(version, about, long_about = None)]
+pub struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand, Clone)]
+enum Commands {
+    #[cfg(feature = "download")]
+    #[command(about="Download heliport model from GitHub")]
+    Download(DownloadCmd),
+    #[command(about="Binarize heliport model")]
+    Binarize(BinarizeCmd),
+    #[command(about="Identify languages of input text", visible_alias="detect")]
+    Identify(IdentifyCmd),
+}
+
+
+
+#[pyfunction]
+pub fn cli_run() -> PyResult<()> {
+    // parse the cli arguments, skip the first one that is the path to the Python entry point
+    let os_args = std::env::args_os().skip(1);
+    let args = Cli::parse_from(os_args);
+    debug!("Module path found at: {}", module_path().expect("Could not find module path").display());
+    env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
+
+    match args.command {
+        #[cfg(feature = "download")]
+        Commands::Download(cmd) => { cmd.cli() },
+        Commands::Binarize(cmd) => { cmd.cli() },
+        Commands::Identify(cmd) => { cmd.cli() },
+    }
+}
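Note, not part of the patch above: the snippet below is a minimal, self-contained sketch of the cfg-gated subcommand pattern that src/cli/mod.rs relies on, assuming clap 4 with the derive feature and a Cargo feature named "download" as declared in Cargo.toml. The Download variant, its import, and its match arm are only compiled when the crate is built with --features download, which is why Cargo.toml no longer pulls "download" into the "cli" feature by default.

// Sketch only (hypothetical crate, not heliport itself): shows how a clap
// subcommand can be hidden behind a Cargo feature, mirroring src/cli/mod.rs.
use clap::{Parser, Subcommand};

#[derive(Parser)]
struct Cli {
    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    // Compiled only with --features download, so the optional tokio/reqwest
    // download stack stays out of default builds.
    #[cfg(feature = "download")]
    Download,
    Identify,
}

fn main() {
    match Cli::parse().command {
        #[cfg(feature = "download")]
        Commands::Download => println!("would download models"),
        Commands::Identify => println!("would identify languages"),
    }
}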