diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 66cfdab..696e60b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,3 +22,5 @@ jobs: - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} - run: cargo build --verbose - run: cargo test --verbose + - run: cargo build --all-features --verbose + - run: cargo test --all-features --verbose diff --git a/.gitignore b/.gitignore index af3ca5e..9dc8179 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target +Cargo.lock # Byte-compiled / optimized / DLL files __pycache__/ @@ -69,4 +70,4 @@ docs/_build/ .vscode/ # Pyenv -.python-version \ No newline at end of file +.python-version diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index 83ea4e3..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,75 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "general-sam" -version = "0.3.0" -dependencies = [ - "rand", -] - -[[package]] -name = "getrandom" -version = "0.2.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "libc" -version = "0.2.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a08173bc88b7955d1b3145aa561539096c421ac8debde8cbc3612ec635fee29b" - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" diff --git a/Cargo.toml b/Cargo.toml index 77473f6..7345363 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,3 +17,7 @@ name = "general_sam" [dev-dependencies] rand = "0.8.5" + +[features] +trie = [] +all = ["trie"] diff --git a/src/lib.rs b/src/lib.rs index 8f91480..6f703bc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,6 +65,7 @@ //! ``` //! //! ```rust +//! # #[cfg(feature = "trie")] { //! use general_sam::{sam::GeneralSAM, trie::Trie}; //! //! let mut trie = Trie::default(); @@ -82,6 +83,7 @@ //! //! assert!(!sam.get_root_state().feed_chars("bye").is_accepting()); //! assert!(sam.get_root_state().feed_chars("bye").is_nil()); +//! # } //! ``` //! //! # References @@ -95,15 +97,19 @@ //! [general-sam-oi-wiki]: https://oi-wiki.org/string/general-sam/ pub mod sam; -pub mod trie; pub mod trie_alike; +#[cfg(feature = "trie")] +pub mod trie; + pub use sam::{ GeneralSAM, GeneralSAMNode, GeneralSAMNodeID, GeneralSAMState, SAM_NIL_NODE_ID, SAM_ROOT_NODE_ID, }; -pub use trie::{Trie, TrieNode, TrieNodeID, TrieState, TRIE_NIL_NODE_ID, TRIE_ROOT_NODE_ID}; pub use trie_alike::{IterAsChain, TravelEvent, TrieNodeAlike}; +#[cfg(feature = "trie")] +pub use trie::{Trie, TrieNode, TrieNodeID, TrieState, TRIE_NIL_NODE_ID, TRIE_ROOT_NODE_ID}; + #[cfg(test)] mod tests; diff --git a/src/sam/mod.rs b/src/sam/mod.rs index 881879a..d83a8cd 100644 --- a/src/sam/mod.rs +++ b/src/sam/mod.rs @@ -1,12 +1,12 @@ mod state; pub use state::GeneralSAMState; -use std::{ - collections::{BTreeMap, VecDeque}, - convert::Infallible, -}; +use std::{collections::BTreeMap, convert::Infallible}; -use crate::trie_alike::{IterAsChain, TravelEvent, TrieNodeAlike}; +use crate::{ + trie_alike::{IterAsChain, TrieNodeAlike}, + TravelEvent, +}; pub type GeneralSAMNodeID = usize; pub const SAM_NIL_NODE_ID: GeneralSAMNodeID = 0; @@ -142,22 +142,14 @@ impl GeneralSAM { where TN::InnerType: Into, { - let mut queue = VecDeque::new(); - let mut last_node_id = SAM_ROOT_NODE_ID; - node.bfs_travel(|event| -> Result<(), Infallible> { + node.bfs_travel(|event| -> Result { match event { - TravelEvent::Push(_, None) => { - queue.push_back(SAM_ROOT_NODE_ID); - } - TravelEvent::Pop(_) => { - last_node_id = queue.pop_front().unwrap(); + TravelEvent::PushRoot(_) => Ok(SAM_ROOT_NODE_ID), + TravelEvent::Push(cur_tn, cur_node_id, key) => { + Ok(self.insert_node_trans(*cur_node_id, key, cur_tn.is_accepting())) } - TravelEvent::Push(tn, Some(key)) => { - let new_node_id = self.insert_node_trans(last_node_id, key, tn.is_accepting()); - queue.push_back(new_node_id); - } - }; - Ok(()) + TravelEvent::Pop(_, cur_node_id) => Ok(cur_node_id), + } }) .unwrap(); } diff --git a/src/sam/state.rs b/src/sam/state.rs index d9fbaad..145d578 100644 --- a/src/sam/state.rs +++ b/src/sam/state.rs @@ -1,5 +1,3 @@ -use std::collections::VecDeque; - use crate::trie_alike::{TravelEvent, TrieNodeAlike}; use super::{GeneralSAM, GeneralSAMNode, SAM_NIL_NODE_ID, SAM_ROOT_NODE_ID}; @@ -16,13 +14,13 @@ impl<'s> GeneralSAMState<'s, u8> { } } -impl<'s> GeneralSAMState<'s, char> { - pub fn feed_chars(self, seq: &'s str) -> Self { +impl GeneralSAMState<'_, char> { + pub fn feed_chars(self, seq: &str) -> Self { self.feed(seq.chars()) } } -impl<'s, T: Ord + Clone> GeneralSAMState<'s, T> { +impl GeneralSAMState<'_, T> { pub fn is_nil(&self) -> bool { self.node_id == SAM_NIL_NODE_ID } @@ -58,116 +56,98 @@ impl<'s, T: Ord + Clone> GeneralSAMState<'s, T> { } } - pub fn feed_ref>(self, seq: Seq) -> Self { - self.feed_ref_iter(seq.into_iter()) + pub fn feed>(self, seq: Seq) -> Self { + self.feed_iter(seq.into_iter()) } - pub fn feed_ref_iter>(mut self, iter: Iter) -> Self { + pub fn feed_iter>(mut self, iter: Iter) -> Self { for t in iter { if self.is_nil() { break; } - self.goto(t) + self.goto(&t) } self } +} - pub fn feed>(self, seq: Seq) -> Self { - self.feed_iter(seq.into_iter()) +impl<'s, T: Ord + Clone> GeneralSAMState<'s, T> { + pub fn feed_ref>(self, seq: Seq) -> Self { + self.feed_ref_iter(seq.into_iter()) } - pub fn feed_iter>(mut self, iter: Iter) -> Self { + pub fn feed_ref_iter>(mut self, iter: Iter) -> Self { for t in iter { if self.is_nil() { break; } - self.goto(&t) + self.goto(t) } self } +} - pub fn bfs_along< - TN: TrieNodeAlike + Sized, - E, - F: FnMut(TravelEvent<(GeneralSAMState<'_, T>, &TN), TN::InnerType>) -> Result<(), E>, +impl<'s, T: Ord + Clone> GeneralSAMState<'s, T> { + fn wrap_travel_along_callback< + TN: TrieNodeAlike, + ExtraType, + ErrorType, + F: 's + + FnMut( + TravelEvent<(&GeneralSAMState, &TN), ExtraType, TN::InnerType>, + ) -> Result, >( - &self, - trie_node: TN, + &'s self, mut callback: F, - ) -> Result<(), E> { - let mut queue = VecDeque::new(); - let mut cur_node_id = self.node_id; - - trie_node.bfs_travel(|event| match event { - TravelEvent::Push(tn, Some(key)) => { - let next_node_id = self - .sam - .node_pool - .get(cur_node_id) - .and_then(|x| x.trans.get(&key).copied()) - .unwrap_or(SAM_NIL_NODE_ID); - callback(TravelEvent::Push( - (self.sam.get_state(next_node_id), tn), - Some(key), - ))?; - queue.push_back(next_node_id); - Ok(()) + ) -> impl FnMut( + TravelEvent<&TN, (GeneralSAMState<'s, T>, ExtraType), TN::InnerType>, + ) -> Result<(GeneralSAMState<'s, T>, ExtraType), ErrorType> { + move |event| match event { + TravelEvent::PushRoot(trie_root) => { + let res = callback(TravelEvent::PushRoot((self, trie_root)))?; + Ok((self.clone(), res)) } - TravelEvent::Push(tn, None) => { - callback(TravelEvent::Push( - (self.sam.get_state(self.node_id), tn), - None, - ))?; - queue.push_back(self.node_id); - Ok(()) + TravelEvent::Push(cur_tn, (cur_state, cur_extra), key) => { + let mut next_state = cur_state.clone(); + next_state.goto(&key); + let next_extra = + callback(TravelEvent::Push((&next_state, &cur_tn), cur_extra, key))?; + Ok((next_state, next_extra)) } - TravelEvent::Pop(tn) => { - cur_node_id = queue.pop_front().unwrap(); - callback(TravelEvent::Pop((self.sam.get_state(cur_node_id), tn)))?; - Ok(()) + TravelEvent::Pop(cur_tn, (cur_state, extra)) => { + let res = callback(TravelEvent::Pop((&cur_state, cur_tn), extra))?; + Ok((cur_state, res)) } - }) + } } pub fn dfs_along< TN: TrieNodeAlike + Clone, - E, - F: FnMut(TravelEvent<(GeneralSAMState<'_, T>, &TN), TN::InnerType>) -> Result<(), E>, + ExtraType, + ErrorType, + F: FnMut( + TravelEvent<(&GeneralSAMState<'_, T>, &TN), ErrorType, TN::InnerType>, + ) -> Result, >( &self, trie_node: TN, - mut callback: F, - ) -> Result<(), E> { - let mut stack: Vec = Vec::new(); - - trie_node.dfs_travel(|event| match event { - TravelEvent::Push(tn, Some(key)) => { - let next_node_id = self - .sam - .node_pool - .get(*stack.last().unwrap()) - .and_then(|x| x.trans.get(&key).copied()) - .unwrap_or(SAM_NIL_NODE_ID); - callback(TravelEvent::Push( - (self.sam.get_state(next_node_id), tn), - Some(key), - ))?; - stack.push(next_node_id); - Ok(()) - } - TravelEvent::Push(tn, None) => { - callback(TravelEvent::Push( - (self.sam.get_state(self.node_id), tn), - None, - ))?; - stack.push(self.node_id); - Ok(()) - } - TravelEvent::Pop(tn) => { - let node_id = stack.pop().unwrap(); - callback(TravelEvent::Pop((self.sam.get_state(node_id), tn)))?; - Ok(()) - } - }) + callback: F, + ) -> Result<(), ExtraType> { + trie_node.dfs_travel(self.wrap_travel_along_callback(callback)) + } + + pub fn bfs_along< + TN: TrieNodeAlike + Clone, + ExtraType, + ErrorType, + F: FnMut( + TravelEvent<(&GeneralSAMState<'_, T>, &TN), ErrorType, TN::InnerType>, + ) -> Result, + >( + &self, + trie_node: TN, + callback: F, + ) -> Result<(), ExtraType> { + trie_node.bfs_travel(self.wrap_travel_along_callback(callback)) } } diff --git a/src/tests.rs b/src/tests.rs index 8aa703d..d3d0829 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,10 +1,4 @@ -use rand::{ - distributions::{Alphanumeric, DistString}, - rngs::StdRng, - Rng, SeedableRng, -}; - -use crate::{sam::GeneralSAM, trie::Trie, SAM_ROOT_NODE_ID}; +use crate::sam::GeneralSAM; #[test] fn test_example_from_chars() { @@ -40,8 +34,11 @@ fn test_example_from_bytes() { assert!(!state.is_accepting() && state.is_nil() && !state.is_root()); } +#[cfg(feature = "trie")] #[test] fn test_example_from_trie() { + use crate::trie::Trie; + let mut trie = Trie::default(); trie.insert_iter("hello".chars()); @@ -120,7 +117,10 @@ fn test_chinese_chars() { println!("state \"你好\": {:?}", state.node_id); } +#[cfg(feature = "trie")] fn test_trie_suffix(vocab: &[&str]) { + use crate::trie::Trie; + let mut trie = Trie::default(); vocab.iter().for_each(|word| { trie.insert_iter(word.chars()); @@ -160,20 +160,32 @@ fn test_trie_suffix(vocab: &[&str]) { }); } +#[cfg(feature = "trie")] #[test] fn test_chiense_trie_suffix() { let vocab = ["歌曲", "聆听歌曲", "播放歌曲", "歌词", "查看歌词"]; test_trie_suffix(&vocab); } +#[cfg(feature = "trie")] #[test] fn test_simple_trie_suffix() { let vocab = ["ac", "bb", "b", "cc", "aabb", "a", "ba", "c", "aa"]; test_trie_suffix(&vocab); } +#[cfg(feature = "trie")] #[test] fn test_topo_and_suf_len_sorted_order() { + use rand::{ + distributions::{Alphanumeric, DistString}, + rngs::StdRng, + Rng, SeedableRng, + }; + + use crate::trie::Trie; + use crate::SAM_ROOT_NODE_ID; + let mut rng = StdRng::seed_from_u64(1134759173975); for _ in 0..10000 { let mut trie = Trie::default(); diff --git a/src/trie_alike.rs b/src/trie_alike.rs index dd4fec2..5f3cb97 100644 --- a/src/trie_alike.rs +++ b/src/trie_alike.rs @@ -1,8 +1,9 @@ use std::collections::VecDeque; -pub enum TravelEvent { - Push(NodeType, Option), - Pop(NodeType), +pub enum TravelEvent<'s, NodeType, ExtraType, KeyType> { + PushRoot(NodeType), + Push(NodeType, &'s ExtraType, KeyType), + Pop(NodeType, ExtraType), } /// This trait provides the essential interfaces required by `GeneralSAM` @@ -13,45 +14,59 @@ pub trait TrieNodeAlike { fn is_accepting(&self) -> bool; fn next_states(self) -> Self::NextStateIter; - fn bfs_travel) -> Result<(), E>>( + fn bfs_travel< + ErrorType, + ExtraType, + F: FnMut(TravelEvent<&Self, ExtraType, Self::InnerType>) -> Result, + >( self, mut callback: F, - ) -> Result<(), E> + ) -> Result<(), ErrorType> where Self: Sized, { let mut queue = VecDeque::new(); - callback(TravelEvent::Push(&self, None))?; - queue.push_back(self); - while let Some(state) = queue.pop_front() { - callback(TravelEvent::Pop(&state))?; + + let extra = callback(TravelEvent::PushRoot(&self))?; + queue.push_back((self, extra)); + + while let Some((state, cur_extra)) = queue.pop_front() { + let cur_extra = callback(TravelEvent::Pop(&state, cur_extra))?; + for (t, v) in state.next_states() { - callback(TravelEvent::Push(&v, Some(t)))?; - queue.push_back(v); + let next_extra = callback(TravelEvent::Push(&v, &cur_extra, t))?; + queue.push_back((v, next_extra)); } } Ok(()) } - fn dfs_travel) -> Result<(), E>>( + fn dfs_travel< + ErrorType, + ExtraType, + F: FnMut(TravelEvent<&Self, ExtraType, Self::InnerType>) -> Result, + >( self, mut callback: F, - ) -> Result<(), E> + ) -> Result<(), ErrorType> where Self: Clone, { let mut stack = Vec::new(); - callback(TravelEvent::Push(&self, None))?; - stack.push((self.clone(), self.next_states())); + let extra = callback(TravelEvent::PushRoot(&self))?; + stack.push((self.clone(), self.next_states(), extra)); - while let Some((ref cur, ref mut iter)) = stack.last_mut() { - if let Some((key, next_state)) = iter.next() { - callback(TravelEvent::Push(&next_state, Some(key)))?; - stack.push((next_state.clone(), next_state.next_states())); - } else { - callback(TravelEvent::Pop(cur))?; - stack.pop(); + while !stack.is_empty() { + if let Some((_, iter, extra)) = stack.last_mut() { + if let Some((key, next_state)) = iter.next() { + let new_extra = callback(TravelEvent::Push(&next_state, extra, key))?; + stack.push((next_state.clone(), next_state.next_states(), new_extra)); + continue; + } + } + if let Some((cur, _, extra)) = stack.pop() { + callback(TravelEvent::Pop(&cur, extra))?; } } Ok(())