Create models command
ZJaume committed Nov 18, 2024
1 parent cb4a373 commit f1a7510
Showing 9 changed files with 185 additions and 10 deletions.
4 changes: 2 additions & 2 deletions .gitignore
@@ -1,5 +1,5 @@
# Ignore binary model files
# but keep heliport.data dir so maturin picks it up when doing a clean build
# Training files
*.train

# Wheels
wheels*
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## v0.8.0
### Added
- Model creation command.

### Changed
- Include binarized model in the wheel.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -43,6 +43,7 @@ anyhow = "1.0"
rayon = "1.10"
itertools = "0.11"
lazy_static = "1.5"
counter = "0.6.0"

[dev-dependencies]
test-log = "0.2.15"
51 changes: 51 additions & 0 deletions src/cli/create_models.rs
@@ -0,0 +1,51 @@
use std::path::PathBuf;
use std::process::exit;
use std::time::Instant;

use anyhow::Context;
use clap::Args;
use log::{info, error};
use pyo3::prelude::*;
use rayon::prelude::*;

use crate::utils::Abort;
use crate::trainer::count_all_ngrams;

#[derive(Args, Clone)]
pub struct CreateModelCmd {
#[arg(help="Output directory to save the ngram frequency files")]
output_dir: PathBuf,
#[arg(help="Directory where input text files are located")]
input_files: Vec<PathBuf>,
#[arg(short = 'k', long, default_value_t = 10000, help="Truncate at top-k most frequent n-grams")]
topk: usize,
}

impl CreateModelCmd {
pub fn cli(self) -> PyResult<()> {
info!("Starting");
let now = Instant::now();

if !self.output_dir.exists() {
error!("Output directory '{}' does not exist, please create it", self.output_dir.display());
exit(1);
}

info!("Saving top {} most frequent n-grams", self.topk);

// Train each file/language in parallel
// use panic_fuse to fail early if one of the jobs fails
self.input_files
.into_par_iter()
.panic_fuse()
.for_each(|lang_file| {
count_all_ngrams(&lang_file, &self.output_dir, self.topk)
.with_context(|| format!("Error with file '{}'", lang_file.display()))
.or_abort(1);
});

info!("Finished");
info!("Elapsed time: {:.2?}", now.elapsed());
Ok(())
}
}
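
Each input file is expected to be named after its language code: count_all_ngrams in src/trainer.rs below extracts the code from the file name with the RE_LANG_NAME regex. A small sketch of that extraction, using the same pattern and a hypothetical input path:

use regex::Regex;

fn main() {
    // Same pattern as RE_LANG_NAME in src/trainer.rs:
    // capture 3-7 word characters that precede the ".train" suffix.
    let re = Regex::new(r"(\w{3,7}).train$").expect("Error compiling lang name regex");

    let file_name = "corpus/fin.train"; // hypothetical input path
    let lang = re
        .captures(file_name)
        .and_then(|c| c.get(1))
        .map(|m| m.as_str());

    // The captured prefix becomes the language code used for the output files.
    assert_eq!(lang, Some("fin"));
}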
5 changes: 5 additions & 0 deletions src/cli/mod.rs
@@ -2,6 +2,7 @@ mod identify;
#[cfg(feature = "download")]
mod download;
mod binarize;
mod create_models;

use clap::{Subcommand, Parser};
use log::{debug};
@@ -13,6 +14,7 @@ use crate::python::module_path;
use self::download::DownloadCmd;
use self::binarize::BinarizeCmd;
use self::identify::IdentifyCmd;
use self::create_models::CreateModelCmd;

#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
@@ -33,6 +35,8 @@ enum Commands {
Binarize(BinarizeCmd),
#[command(about="Identify languages of input text", visible_alias="detect")]
Identify(IdentifyCmd),
#[command(about="Create heliport models")]
CreateModel(CreateModelCmd),
}


@@ -54,5 +58,6 @@ pub fn cli_run() -> PyResult<()> {
Commands::Download(cmd) => { cmd.cli() },
Commands::Binarize(cmd) => { cmd.cli() },
Commands::Identify(cmd) => { cmd.cli() },
Commands::CreateModel(cmd) => { cmd.cli() },
}
}
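
With clap's derive API, the new CreateModel variant is exposed as the create-model subcommand (clap converts variant names to kebab-case by default). A minimal sketch of that wiring, using hypothetical DemoCli/DemoCommands types rather than the real ones and assuming clap's derive feature (which this crate already enables):

use clap::{Parser, Subcommand};

// Hypothetical mirror of the Cli/Commands shape above, trimmed to one subcommand.
#[derive(Parser)]
struct DemoCli {
    #[command(subcommand)]
    command: DemoCommands,
}

#[derive(Subcommand)]
enum DemoCommands {
    // clap derives the kebab-case name "create-model" from this variant name.
    CreateModel {
        output_dir: String,
        input_files: Vec<String>,
    },
}

fn main() {
    let cli = DemoCli::try_parse_from(["heliport", "create-model", "models/", "eng.train"])
        .expect("the kebab-case subcommand name should parse");
    match cli.command {
        DemoCommands::CreateModel { output_dir, input_files } => {
            assert_eq!(output_dir, "models/");
            assert_eq!(input_files, vec!["eng.train"]);
        }
    }
}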
8 changes: 1 addition & 7 deletions src/identifier.rs
@@ -5,23 +5,17 @@ use std::sync::{Arc, Mutex};
use ordered_float::OrderedFloat;
use strum::{IntoEnumIterator, EnumCount};
use shingles::AsShingles;
use regex::Regex;
use anyhow::Result;
use log::{debug,warn};
use lazy_static::lazy_static;
use rayon::prelude::*;

#[cfg(feature = "python")]
use pyo3::pyclass;

use heliport_model::Model;
use heliport_model::{Lang, LangScores, LangBitmap};
use crate::utils::is_cjk_block;
use crate::utils::{is_cjk_block, RE_NON_ALPHA};

lazy_static! {
static ref RE_NON_ALPHA: Regex = Regex::new(r#"[^#gc\p{L}\p{M}′'’´ʹािीुूृेैोौंँः् া ি ী ু ূ ৃ ে ৈ ো ৌ।্্্я̄\u07A6\u07A7\u07A8\u07A9\u07AA\u07AB\u07AC\u07AD\u07AE\u07AF\u07B0\u0A81\u0A82\u0A83\u0ABC\u0ABD\u0ABE\u0ABF\u0AC0\u0AC1\u0AC2\u0AC3\u0AC4\u0AC5\u0AC6\u0AC7\u0AC8\u0AC9\u0ACA\u0ACB\u0ACC\u0ACD\u0AD0\u0AE0\u0AE1\u0AE2\u0AE3\u0AE4\u0AE5\u0AE6\u0AE7\u0AE8\u0AE9\u0AEA\u0AEB\u0AEC\u0AED\u0AEE\u0AEF\u0AF0\u0AF1]"#)
.expect("Error compiling non-alpha regex for Idenfifier");
}

#[cfg_attr(feature = "python", pyclass)]
pub struct Identifier {
1 change: 1 addition & 0 deletions src/lib.rs
@@ -6,3 +6,4 @@ pub mod utils;
mod cli;
#[cfg(feature = "python")]
mod python;
pub mod trainer;
114 changes: 114 additions & 0 deletions src/trainer.rs
@@ -0,0 +1,114 @@
use std::fs::File;
use std::io::{BufRead, BufReader, Write, BufWriter};
use std::path::Path;

use anyhow::{Context, Result};
use counter::Counter;
use lazy_static::lazy_static;
use log::{info, debug};
use rayon::prelude::*;
use regex::Regex;
use shingles::AsShingles;
use strum::IntoEnumIterator;

use crate::utils::RE_NON_ALPHA;

use heliport_model::{OrderNgram};


lazy_static! {
static ref RE_LANG_NAME: Regex = Regex::new(r"(\w{3,7}).train$")
.expect("Error compiling lang name from file regex");
}

// Count n-gram frequency of a given n-gram order in the text contained in the file
fn count_ngrams(input_file_path: &Path, order: OrderNgram) -> Result<Counter<String>> {
let input_file = BufReader::new(File::open(input_file_path)?);
let mut counts = Counter::new();

// Read training file line by line and accumulate ngram counts
for line_res in input_file.lines() {
let line = line_res?;
// Replace punctuation by spaces
let replaced = RE_NON_ALPHA.replace_all(&line, " ");

// iterate over words
for word in replaced.split_whitespace() {
// if the current order is word, just count the words;
// otherwise, add space boundaries around the word,
// generate all possible n-grams of the current order
// and count them
if order == OrderNgram::Word {
if let Some(entry) = counts.get_mut(word) {
*entry += 1;
} else {
counts.insert(String::from(word), 1);
}
} else {
let wordspace = format!(" {word} ");
// order can be cast to integer because the internal representations
// have the same number (word is 0, unigram is 1 and so on)
for gram in wordspace.as_shingles(order as usize) {
if let Some(entry) = counts.get_mut(gram) {
*entry += 1;
} else {
counts.insert(String::from(gram), 1);
}
}
}
}
}

Ok(counts)
}

// Count n-gram frequency of all n-gram orders for a given language
pub fn count_all_ngrams(input_file_path: &Path, output_dir: &Path, top_k: usize) -> Result<()> {
// use the lang prefix in the input file as language code
let string_file_name = input_file_path.to_string_lossy();
let lang_string = RE_LANG_NAME
.captures(&string_file_name)
.context("Could not parse language name from input_file")?
.get(1)
.with_context(|| "Could not get first capture group from lang name regex")?
.as_str();
// Check that the language exists
// avoid this for now, as it would require compiling with a new lang before training
// let lang = Lang::from_str(&lang_string)
// .with_context(|| format!("Could not parse lang '{lang_string}'"))?;
info!("Training '{lang_string}'");

// Run training for each n-gram order in parallel
let ngram_orders: Vec<_> = OrderNgram::iter().collect();
let results: Vec<Result<_>> = ngram_orders
.into_par_iter()
.map(|order| -> Result<()> {
// Obtain n-gram frequencies
let counts = count_ngrams(input_file_path, order)?;
// create output file with the language code and ngram order as name
let output_file =
File::create(output_dir.join(format!("{}.{}.model", lang_string, order.to_string())))
.with_context(|| "Could not create file")?;
let mut output_file = BufWriter::new(output_file);
let total = counts.total::<usize>();
debug!(
"Total: {} top-10: {:?}",
total,
counts.k_most_common_ordered(10)
);

// Write the top-k most frequent n-grams with their frequencies and the total count
writeln!(&mut output_file, "{}", total)?;
for (ngram, count) in counts.k_most_common_ordered(top_k) {
writeln!(&mut output_file, "{ngram}\t{count}")?;
}
Ok(())
}).collect();

for r in results {
let _ = r?;
}

info!("Finished '{lang_string}'");
Ok(())
}
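
To make the n-gram counting concrete: for every order other than Word, count_ngrams pads each word with spaces and counts every character window of that length (assuming as_shingles yields character-level windows here). A rough std-only sketch of the same idea, with the resulting model-file layout noted in comments:

use std::collections::HashMap;

// Count character n-grams of a given order in a single word,
// padding with spaces the same way count_ngrams does (" word ").
fn char_ngrams(word: &str, order: usize) -> HashMap<String, usize> {
    let padded = format!(" {word} ");
    let chars: Vec<char> = padded.chars().collect();
    let mut counts = HashMap::new();
    for window in chars.windows(order) {
        *counts.entry(window.iter().collect::<String>()).or_insert(0) += 1;
    }
    counts
}

fn main() {
    // For order 2 (bigrams), "cat" padded to " cat " yields " c", "ca", "at", "t ".
    let bigrams = char_ngrams("cat", 2);
    assert_eq!(bigrams.get("ca"), Some(&1));
    assert_eq!(bigrams.len(), 4);

    // Each <lang>.<order>.model file written by count_all_ngrams starts with the
    // total n-gram count on the first line, followed by tab-separated
    // "ngram<TAB>count" lines, truncated to the top-k most frequent entries.
}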
10 changes: 9 additions & 1 deletion src/utils.rs
@@ -1,8 +1,15 @@
use std::process::exit;

use lazy_static::lazy_static;
use log::error;
use regex::Regex;
use unicode_blocks;

lazy_static! {
pub static ref RE_NON_ALPHA: Regex = Regex::new(r#"[^#gc\p{L}\p{M}′'’´ʹािीुूृेैोौंँः् া ি ী ু ূ ৃ ে ৈ ো ৌ।্্্я̄\u07A6\u07A7\u07A8\u07A9\u07AA\u07AB\u07AC\u07AD\u07AE\u07AF\u07B0\u0A81\u0A82\u0A83\u0ABC\u0ABD\u0ABE\u0ABF\u0AC0\u0AC1\u0AC2\u0AC3\u0AC4\u0AC5\u0AC6\u0AC7\u0AC8\u0AC9\u0ACA\u0ACB\u0ACC\u0ACD\u0AD0\u0AE0\u0AE1\u0AE2\u0AE3\u0AE4\u0AE5\u0AE6\u0AE7\u0AE8\u0AE9\u0AEA\u0AEB\u0AEC\u0AED\u0AEE\u0AEF\u0AF0\u0AF1]"#)
.expect("Error compiling non-alpha regex for Idenfifier");
}

// Trait that extracts the contained ok value or aborts if error
// sending the error message to the log
pub trait Abort<T> {
@@ -14,7 +21,8 @@ impl<T, E: std::fmt::Display> Abort<T> for Result<T, E>
fn or_abort(self, exit_code: i32) -> T {
match self {
Ok(v) => v,
Err(e) => { error!("{e}"); exit(exit_code); },
// Print the whole error context with :#
Err(e) => { error!("{e:#}"); exit(exit_code); },
}
}
}
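
The switch from {e} to {e:#} matters because anyhow prints only the outermost message with plain Display, while alternate formatting walks the whole context chain. A small sketch of the difference, using a hypothetical missing file to produce a chained error:

use anyhow::{Context, Result};

fn load() -> Result<String> {
    // hypothetical missing file, just to produce an error with context
    std::fs::read_to_string("eng.train")
        .context("Error with file 'eng.train'")
}

fn main() {
    if let Err(e) = load() {
        // "{e}" shows only the outermost message:
        //   Error with file 'eng.train'
        eprintln!("{e}");
        // "{e:#}" shows the whole context chain, roughly:
        //   Error with file 'eng.train': No such file or directory (os error 2)
        eprintln!("{e:#}");
    }
}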
