From f1a7510285212b4ae07e2c160ea4c0bfa494228f Mon Sep 17 00:00:00 2001
From: ZJaume
Date: Wed, 13 Nov 2024 12:07:39 +0000
Subject: [PATCH] Create models command

---
 .gitignore               |   4 +-
 CHANGELOG.md             |   1 +
 Cargo.toml               |   1 +
 src/cli/create_models.rs |  51 ++++++++++++++++++
 src/cli/mod.rs           |   5 ++
 src/identifier.rs        |   8 +--
 src/lib.rs               |   1 +
 src/trainer.rs           | 114 +++++++++++++++++++++++++++++++++++++++
 src/utils.rs             |  10 +++-
 9 files changed, 185 insertions(+), 10 deletions(-)
 create mode 100644 src/cli/create_models.rs
 create mode 100644 src/trainer.rs

diff --git a/.gitignore b/.gitignore
index 18531cc..1337c57 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,5 @@
-# Ignore binary model files
-# but keep heliport.data dir so maturin picks it up when doing a clean build
+# Training files
+*.train
 
 # Wheels
 wheels*
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 48941a9..89f755b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## v0.8.0
 ### Added
+- Model creation command.
 ### Changed
 - Include binarized model in the wheel.
 
diff --git a/Cargo.toml b/Cargo.toml
index 781f623..e288081 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -43,6 +43,7 @@ anyhow = "1.0"
 rayon = "1.10"
 itertools = "0.11"
 lazy_static = "1.5"
+counter = "0.6.0"
 
 [dev-dependencies]
 test-log = "0.2.15"
diff --git a/src/cli/create_models.rs b/src/cli/create_models.rs
new file mode 100644
index 0000000..8124f6b
--- /dev/null
+++ b/src/cli/create_models.rs
@@ -0,0 +1,51 @@
+use std::path::PathBuf;
+use std::process::exit;
+use std::time::Instant;
+
+use anyhow::Context;
+use clap::Args;
+use log::{info, error};
+use pyo3::prelude::*;
+use rayon::prelude::*;
+
+use crate::utils::Abort;
+use crate::trainer::count_all_ngrams;
+
+#[derive(Args, Clone)]
+pub struct CreateModelCmd {
+    #[arg(help="Output directory to save the ngram frequency files")]
+    output_dir: PathBuf,
+    #[arg(help="Input text files for training, one per language")]
+    input_files: Vec<PathBuf>,
+    #[arg(short = 'k', long, default_value_t = 10000, help="Truncate at top-k most frequent n-grams")]
+    topk: usize,
+}
+
+impl CreateModelCmd {
+    pub fn cli(self) -> PyResult<()> {
+        info!("Starting");
+        let now = Instant::now();
+
+        if !self.output_dir.exists() {
+            error!("Output directory '{}' does not exist, please create it", self.output_dir.display());
+            exit(1);
+        }
+
+        info!("Saving top {} most frequent n-grams", self.topk);
+
+        // Train each file/language in parallel
+        // use panic_fuse to fail early if one of the jobs fails
+        self.input_files
+            .into_par_iter()
+            .panic_fuse()
+            .for_each(|lang_file| {
+                count_all_ngrams(&lang_file, &self.output_dir, self.topk)
+                    .with_context(|| format!("Error with file '{}'", lang_file.display()))
+                    .or_abort(1);
+            });
+
+        info!("Finished");
+        info!("Elapsed time: {:.2?}", now.elapsed());
+        Ok(())
+    }
+}
diff --git a/src/cli/mod.rs b/src/cli/mod.rs
index fc222f8..7077c2a 100644
--- a/src/cli/mod.rs
+++ b/src/cli/mod.rs
@@ -2,6 +2,7 @@ mod identify;
 #[cfg(feature = "download")]
 mod download;
 mod binarize;
+mod create_models;
 
 use clap::{Subcommand, Parser};
 use log::{debug};
@@ -13,6 +14,7 @@ use crate::python::module_path;
 use self::download::DownloadCmd;
 use self::binarize::BinarizeCmd;
 use self::identify::IdentifyCmd;
+use self::create_models::CreateModelCmd;
 
 #[derive(Parser, Clone)]
 #[command(version, about, long_about = None)]
@@ -33,6 +35,8 @@ enum Commands {
     Binarize(BinarizeCmd),
     #[command(about="Identify languages of input text", visible_alias="detect")]
     Identify(IdentifyCmd),
+    #[command(about="Create heliport models")]
+    CreateModel(CreateModelCmd),
 }
 
 
@@ -54,5 +58,6 @@ pub fn cli_run() -> PyResult<()> {
         Commands::Download(cmd) => { cmd.cli() },
         Commands::Binarize(cmd) => { cmd.cli() },
         Commands::Identify(cmd) => { cmd.cli() },
+        Commands::CreateModel(cmd) => { cmd.cli() },
     }
 }
diff --git a/src/identifier.rs b/src/identifier.rs
index c790239..83350ca 100644
--- a/src/identifier.rs
+++ b/src/identifier.rs
@@ -5,10 +5,8 @@ use std::sync::{Arc, Mutex};
 use ordered_float::OrderedFloat;
 use strum::{IntoEnumIterator, EnumCount};
 use shingles::AsShingles;
-use regex::Regex;
 use anyhow::Result;
 use log::{debug,warn};
-use lazy_static::lazy_static;
 use rayon::prelude::*;
 #[cfg(feature = "python")]
 use pyo3::pyclass;
@@ -16,12 +14,8 @@ use pyo3::pyclass;
 
 use heliport_model::Model;
 use heliport_model::{Lang, LangScores, LangBitmap};
-use crate::utils::is_cjk_block;
+use crate::utils::{is_cjk_block, RE_NON_ALPHA};
 
-lazy_static! {
-    static ref RE_NON_ALPHA: Regex = Regex::new(r#"[^#gc\p{L}\p{M}′'’´ʹािीुूृेैोौंँः् া ি ী ু ূ ৃ ে ৈ ো ৌ।্্্я̄\u07A6\u07A7\u07A8\u07A9\u07AA\u07AB\u07AC\u07AD\u07AE\u07AF\u07B0\u0A81\u0A82\u0A83\u0ABC\u0ABD\u0ABE\u0ABF\u0AC0\u0AC1\u0AC2\u0AC3\u0AC4\u0AC5\u0AC6\u0AC7\u0AC8\u0AC9\u0ACA\u0ACB\u0ACC\u0ACD\u0AD0\u0AE0\u0AE1\u0AE2\u0AE3\u0AE4\u0AE5\u0AE6\u0AE7\u0AE8\u0AE9\u0AEA\u0AEB\u0AEC\u0AED\u0AEE\u0AEF\u0AF0\u0AF1]"#)
-        .expect("Error compiling non-alpha regex for Idenfifier");
-}
 
 #[cfg_attr(feature = "python", pyclass)]
 pub struct Identifier {
diff --git a/src/lib.rs b/src/lib.rs
index d1247b0..e13063c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -6,3 +6,4 @@ pub mod utils;
 mod cli;
 #[cfg(feature = "python")]
 mod python;
+pub mod trainer;
diff --git a/src/trainer.rs b/src/trainer.rs
new file mode 100644
index 0000000..9d2a851
--- /dev/null
+++ b/src/trainer.rs
@@ -0,0 +1,114 @@
+use std::fs::File;
+use std::io::{BufRead, BufReader, Write, BufWriter};
+use std::path::Path;
+
+use anyhow::{Context, Result};
+use counter::Counter;
+use lazy_static::lazy_static;
+use log::{info, debug};
+use rayon::prelude::*;
+use regex::Regex;
+use shingles::AsShingles;
+use strum::IntoEnumIterator;
+
+use crate::utils::RE_NON_ALPHA;
+
+use heliport_model::OrderNgram;
+
+
+lazy_static! {
+    static ref RE_LANG_NAME: Regex = Regex::new(r"(\w{3,7})\.train$")
+        .expect("Error compiling lang name from file regex");
+}
+
+// Count n-gram frequency of a given n-gram order in the text contained in the file
+fn count_ngrams(input_file_path: &Path, order: OrderNgram) -> Result<Counter<String>> {
+    let input_file = BufReader::new(File::open(input_file_path)?);
+    let mut counts = Counter::new();
+
+    // Read training file line by line and accumulate ngram counts
+    for line_res in input_file.lines() {
+        let line = line_res?;
+        // Replace punctuation by spaces
+        let replaced = RE_NON_ALPHA.replace_all(&line, " ");
+
+        // iterate over words
+        for word in replaced.split_whitespace() {
+            // if the current order is word, just count whole words;
+            // otherwise, surround the word with space boundaries,
+            // generate all possible ngrams of the current order
+            // and count them
+            if order == OrderNgram::Word {
+                if let Some(entry) = counts.get_mut(word) {
+                    *entry += 1;
+                } else {
+                    counts.insert(String::from(word), 1);
+                }
+            } else {
+                let wordspace = format!(" {word} ");
+                // order can be cast to integer because the internal representations
+                // have the same number (word is 0, unigram is 1 and so on)
+                for gram in wordspace.as_shingles(order as usize) {
+                    if let Some(entry) = counts.get_mut(gram) {
+                        *entry += 1;
+                    } else {
+                        counts.insert(String::from(gram), 1);
+                    }
+                }
+            }
+        }
+    }
+
+    Ok(counts)
+}
+
+// Count n-gram frequencies of all n-gram orders for a given language
+pub fn count_all_ngrams(input_file_path: &Path, output_dir: &Path, top_k: usize) -> Result<()> {
+    // use the lang prefix of the input file name as the language code
+    let string_file_name = input_file_path.to_string_lossy();
+    let lang_string = RE_LANG_NAME
+        .captures(&string_file_name)
+        .context("Could not parse language name from input_file")?
+        .get(1)
+        .with_context(|| "Could not get first capture group from lang name regex")?
+        .as_str();
+    // Check that the language exists
+    // avoid this for now, as it would require compiling with a new lang before training
+    // let lang = Lang::from_str(&lang_string)
+    //     .with_context(|| format!("Could not parse lang '{lang_string}'"))?;
+    info!("Training '{lang_string}'");
+
+    // Run training for each ngram order in parallel
+    let ngram_orders: Vec<_> = OrderNgram::iter().collect();
+    let results: Vec<Result<()>> = ngram_orders
+        .into_par_iter()
+        .map(|order| -> Result<()> {
+            // Obtain ngram frequencies
+            let counts = count_ngrams(input_file_path, order)?;
+            // create output file with the language code and ngram order as name
+            let output_file =
+                File::create(output_dir.join(format!("{}.{}.model", lang_string, order.to_string())))
+                .with_context(|| "Could not create file")?;
+            let mut output_file = BufWriter::new(output_file);
+            let total = counts.total::<usize>();
+            debug!(
+                "Total: {} top-10: {:?}",
+                total,
+                counts.k_most_common_ordered(10)
+            );
+
+            // Write the top-k most frequent n-grams with their frequencies and the total count
+            writeln!(&mut output_file, "{}", total)?;
+            for (ngram, count) in counts.k_most_common_ordered(top_k) {
+                writeln!(&mut output_file, "{ngram}\t{count}")?;
+            }
+            Ok(())
+        }).collect();
+
+    for r in results {
+        let _ = r?;
+    }
+
+    info!("Finished '{lang_string}'");
+    Ok(())
+}
diff --git a/src/utils.rs b/src/utils.rs
index 2dabaca..d58514d 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -1,8 +1,15 @@
 use std::process::exit;
 
+use lazy_static::lazy_static;
 use log::error;
+use regex::Regex;
 use unicode_blocks;
 
+lazy_static! {
+    pub static ref RE_NON_ALPHA: Regex = Regex::new(r#"[^#gc\p{L}\p{M}′'’´ʹािीुूृेैोौंँः् া ি ী ু ূ ৃ ে ৈ ো ৌ।্্্я̄\u07A6\u07A7\u07A8\u07A9\u07AA\u07AB\u07AC\u07AD\u07AE\u07AF\u07B0\u0A81\u0A82\u0A83\u0ABC\u0ABD\u0ABE\u0ABF\u0AC0\u0AC1\u0AC2\u0AC3\u0AC4\u0AC5\u0AC6\u0AC7\u0AC8\u0AC9\u0ACA\u0ACB\u0ACC\u0ACD\u0AD0\u0AE0\u0AE1\u0AE2\u0AE3\u0AE4\u0AE5\u0AE6\u0AE7\u0AE8\u0AE9\u0AEA\u0AEB\u0AEC\u0AED\u0AEE\u0AEF\u0AF0\u0AF1]"#)
+        .expect("Error compiling non-alpha regex for Identifier");
+}
+
 // Trait that extracts the contained ok value or aborts if error
 // sending the error message to the log
 pub trait Abort<T> {
@@ -14,7 +21,8 @@ impl<T, E: Display> Abort<T> for Result<T, E> {
     fn or_abort(self, exit_code: i32) -> T {
         match self {
             Ok(v) => v,
-            Err(e) => { error!("{e}"); exit(exit_code); },
+            // Print the whole error context with :#
+            Err(e) => { error!("{e:#}"); exit(exit_code); },
         }
     }
 }
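
Reviewer note, not part of the patch: a usage sketch for the new subcommand. The binary name `heliport`, the kebab-case spelling `create-model` (clap's default renaming of the `CreateModel` variant), and the corpus paths are all assumptions for illustration, not taken from this diff:

    # The output directory must already exist; the command aborts otherwise.
    mkdir -p models
    # One set of models per input file; the language code is parsed from the
    # file name by the (\w{3,7})\.train$ regex in trainer.rs.
    heliport create-model models corpora/eng.train corpora/cat.train -k 10000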
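
Reviewer note, not part of the patch: the `{lang}.{order}.model` files written by `count_all_ngrams` are plain text, with the total n-gram count for that order on the first line, followed by up to top-k tab-separated `ngram<TAB>count` rows in descending frequency. The values below are invented to show the shape only; the order part of the file name depends on how `OrderNgram` renders with `to_string()`:

    524288
    the	9001
    and	7420
    ...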