diff --git a/Cargo.toml b/Cargo.toml
index 05160120..05fec5a8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -5,11 +5,11 @@ edition = "2021"
 
 [lib]
 name = "outlines_core_rs"
-crate-type = ["cdylib"]
+crate-type = ["cdylib", "rlib"]
 
 [dependencies]
 anyhow = "1.0.86"
-pyo3 = { version = "0.22.0", features = ["extension-module"] }
+pyo3 = { version = "0.22.0", features = ["extension-module"], optional=true }
 regex = "1.10.6"
 serde-pyobject = "0.4.0"
 serde_json = { version ="1.0.125", features = ["preserve_order"] }
@@ -20,3 +20,6 @@ lto = true
 codegen-units = 1
 strip = true
 panic = 'abort'
+
+[features]
+python-bindings = ["pyo3"]
diff --git a/setup.py b/setup.py
index 4c414e6b..d19e0715 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@
         "outlines_core.fsm.outlines_core_rs",
         f"{CURRENT_DIR}/Cargo.toml",
         binding=Binding.PyO3,
+        features=["python-bindings"],
         rustc_flags=["--crate-type=cdylib"],
     ),
 ]
diff --git a/src/json_schema/types.rs b/src/json_schema/types.rs
index 02c748f6..aff5e53b 100644
--- a/src/json_schema/types.rs
+++ b/src/json_schema/types.rs
@@ -53,6 +53,7 @@ impl FormatType {
         }
     }
 
+    #[allow(clippy::should_implement_trait)]
     pub fn from_str(s: &str) -> Option<Self> {
         match s {
             "date-time" => Some(FormatType::DateTime),
diff --git a/src/lib.rs b/src/lib.rs
index bd129f6d..0bc900d2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,57 +1,5 @@
-mod json_schema;
-mod regex;
+pub mod json_schema;
+pub mod regex;
 
-use pyo3::exceptions::PyValueError;
-use pyo3::prelude::*;
-use pyo3::types::PyDict;
-use pyo3::wrap_pyfunction;
-use regex::_walk_fsm;
-use regex::create_fsm_index_end_to_end;
-use regex::get_token_transition_keys;
-use regex::get_vocabulary_transition_keys;
-use regex::state_scan_tokens;
-use regex::FSMInfo;
-use serde_json::Value;
-
-#[pymodule]
-fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(_walk_fsm, m)?)?;
-    m.add_function(wrap_pyfunction!(state_scan_tokens, m)?)?;
-    m.add_function(wrap_pyfunction!(get_token_transition_keys, m)?)?;
-    m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys, m)?)?;
-    m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end, m)?)?;
-
-    m.add_class::<FSMInfo>()?;
-
-    m.add("BOOLEAN", json_schema::BOOLEAN)?;
-    m.add("DATE", json_schema::DATE)?;
-    m.add("DATE_TIME", json_schema::DATE_TIME)?;
-    m.add("INTEGER", json_schema::INTEGER)?;
-    m.add("NULL", json_schema::NULL)?;
-    m.add("NUMBER", json_schema::NUMBER)?;
-    m.add("STRING", json_schema::STRING)?;
-    m.add("STRING_INNER", json_schema::STRING_INNER)?;
-    m.add("TIME", json_schema::TIME)?;
-    m.add("UUID", json_schema::UUID)?;
-    m.add("WHITESPACE", json_schema::WHITESPACE)?;
-
-    m.add_function(wrap_pyfunction!(build_regex_from_schema, m)?)?;
-    m.add_function(wrap_pyfunction!(to_regex, m)?)?;
-
-    Ok(())
-}
-
-#[pyfunction(name = "build_regex_from_schema")]
-#[pyo3(signature = (json, whitespace_pattern=None))]
-pub fn build_regex_from_schema(json: String, whitespace_pattern: Option<&str>) -> PyResult<String> {
-    json_schema::build_regex_from_schema(&json, whitespace_pattern)
-        .map_err(|e| PyValueError::new_err(e.to_string()))
-}
-
-#[pyfunction(name = "to_regex")]
-#[pyo3(signature = (json, whitespace_pattern=None))]
-pub fn to_regex(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResult<String> {
-    let json_value: Value = serde_pyobject::from_pyobject(json).unwrap();
-    json_schema::to_regex(&json_value, whitespace_pattern, &json_value)
-        .map_err(|e| PyValueError::new_err(e.to_string()))
-}
+#[cfg(feature = "python-bindings")]
+mod python_bindings;
diff --git a/src/python_bindings/mod.rs b/src/python_bindings/mod.rs
new file mode 100644
index 00000000..22bebe62
--- /dev/null
+++ b/src/python_bindings/mod.rs
@@ -0,0 +1,221 @@
+use crate::json_schema;
+use crate::regex::get_token_transition_keys;
+use crate::regex::get_vocabulary_transition_keys;
+use crate::regex::state_scan_tokens;
+use crate::regex::walk_fsm;
+use pyo3::exceptions::PyValueError;
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+use pyo3::wrap_pyfunction;
+use serde_json::Value;
+use std::collections::{HashMap, HashSet};
+
+#[pyclass]
+pub struct FSMInfo {
+    #[pyo3(get)]
+    initial: u32,
+    #[pyo3(get)]
+    finals: HashSet<u32>,
+    #[pyo3(get)]
+    transitions: HashMap<(u32, u32), u32>,
+    #[pyo3(get)]
+    alphabet_anything_value: u32,
+    #[pyo3(get)]
+    alphabet_symbol_mapping: HashMap<String, u32>,
+}
+
+#[pymethods]
+impl FSMInfo {
+    #[new]
+    fn new(
+        initial: u32,
+        finals: HashSet<u32>,
+        transitions: HashMap<(u32, u32), u32>,
+        alphabet_anything_value: u32,
+        alphabet_symbol_mapping: HashMap<String, u32>,
+    ) -> Self {
+        Self {
+            initial,
+            finals,
+            transitions,
+            alphabet_anything_value,
+            alphabet_symbol_mapping,
+        }
+    }
+}
+
+#[pyfunction(name = "build_regex_from_schema")]
+#[pyo3(signature = (json, whitespace_pattern=None))]
+pub fn build_regex_from_schema_py(
+    json: String,
+    whitespace_pattern: Option<&str>,
+) -> PyResult<String> {
+    json_schema::build_regex_from_schema(&json, whitespace_pattern)
+        .map_err(|e| PyValueError::new_err(e.to_string()))
+}
+
+#[pyfunction(name = "to_regex")]
+#[pyo3(signature = (json, whitespace_pattern=None))]
+pub fn to_regex_py(json: Bound<PyDict>, whitespace_pattern: Option<&str>) -> PyResult<String> {
+    let json_value: Value = serde_pyobject::from_pyobject(json).unwrap();
+    json_schema::to_regex(&json_value, whitespace_pattern, &json_value)
+        .map_err(|e| PyValueError::new_err(e.to_string()))
+}
+
+#[pyfunction(name = "_walk_fsm")]
+#[pyo3(
+    text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
+)]
+pub fn walk_fsm_py(
+    fsm_transitions: HashMap<(u32, u32), u32>,
+    fsm_initial: u32,
+    fsm_finals: HashSet<u32>,
+    token_transition_keys: Vec<u32>,
+    start_state: u32,
+    full_match: bool,
+) -> PyResult<Vec<u32>> {
+    Ok(walk_fsm(
+        &fsm_transitions,
+        fsm_initial,
+        &fsm_finals,
+        &token_transition_keys,
+        start_state,
+        full_match,
+    ))
+}
+
+#[pyfunction(name = "state_scan_tokens")]
+#[pyo3(
+    text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
+)]
+pub fn state_scan_tokens_py(
+    fsm_transitions: HashMap<(u32, u32), u32>,
+    fsm_initial: u32,
+    fsm_finals: HashSet<u32>,
+    vocabulary: Vec<(String, Vec<u32>)>,
+    vocabulary_transition_keys: Vec<Vec<u32>>,
+    start_state: u32,
+) -> PyResult<HashSet<(u32, u32)>> {
+    Ok(state_scan_tokens(
+        &fsm_transitions,
+        fsm_initial,
+        &fsm_finals,
+        &vocabulary,
+        &vocabulary_transition_keys,
+        start_state,
+    ))
+}
+
+#[pyfunction(name = "get_token_transition_keys")]
+#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
+pub fn get_token_transition_keys_py(
+    alphabet_symbol_mapping: HashMap<String, u32>,
+    alphabet_anything_value: u32,
+    token_str: String,
+) -> PyResult<Vec<u32>> {
+    Ok(get_token_transition_keys(
+        &alphabet_symbol_mapping,
+        alphabet_anything_value,
+        &token_str,
+    ))
+}
+
+#[pyfunction(name = "get_vocabulary_transition_keys")]
+#[pyo3(
+    text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
+)]
+pub fn get_vocabulary_transition_keys_py(
+    alphabet_symbol_mapping: HashMap<String, u32>,
+    alphabet_anything_value: u32,
+    vocabulary: Vec<(String, Vec<u32>)>,
+    frozen_tokens: HashSet<String>,
+) -> PyResult<Vec<Vec<u32>>> {
+    Ok(get_vocabulary_transition_keys(
+        &alphabet_symbol_mapping,
+        alphabet_anything_value,
+        &vocabulary,
+        &frozen_tokens,
+    ))
+}
+
+#[pyfunction(name = "create_fsm_index_end_to_end")]
+#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")]
+pub fn create_fsm_index_end_to_end_py<'py>(
+    py: Python<'py>,
+    fsm_info: &FSMInfo,
+    vocabulary: Vec<(String, Vec<u32>)>,
+    frozen_tokens: HashSet<String>,
+) -> PyResult<Bound<'py, PyDict>> {
+    let states_to_token_subsets = PyDict::new_bound(py);
+    let mut seen: HashSet<u32> = HashSet::new();
+    let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]);
+
+    let vocabulary_transition_keys = get_vocabulary_transition_keys(
+        &fsm_info.alphabet_symbol_mapping,
+        fsm_info.alphabet_anything_value,
+        &vocabulary,
+        &frozen_tokens,
+    );
+
+    while let Some(start_state) = next_states.iter().cloned().next() {
+        next_states.remove(&start_state);
+
+        // TODO: Return Pydict directly at construction
+        let token_ids_end_states = state_scan_tokens(
+            &fsm_info.transitions,
+            fsm_info.initial,
+            &fsm_info.finals,
+            &vocabulary,
+            &vocabulary_transition_keys,
+            start_state,
+        );
+
+        for (token_id, end_state) in token_ids_end_states {
+            if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) {
+                existing_dict.set_item(token_id, end_state).unwrap();
+            } else {
+                let new_dict = PyDict::new_bound(py);
+                new_dict.set_item(token_id, end_state).unwrap();
+                states_to_token_subsets
+                    .set_item(start_state, new_dict)
+                    .unwrap();
+            }
+
+            if !seen.contains(&end_state) {
+                next_states.insert(end_state);
+            }
+        }
+
+        seen.insert(start_state);
+    }
+
+    Ok(states_to_token_subsets)
+}
+
+#[pymodule]
+fn outlines_core_rs(m: &Bound<'_, PyModule>) -> PyResult<()> {
+    m.add_function(wrap_pyfunction!(walk_fsm_py, m)?)?;
+    m.add_function(wrap_pyfunction!(state_scan_tokens_py, m)?)?;
+    m.add_function(wrap_pyfunction!(get_token_transition_keys_py, m)?)?;
+    m.add_function(wrap_pyfunction!(get_vocabulary_transition_keys_py, m)?)?;
+    m.add_function(wrap_pyfunction!(create_fsm_index_end_to_end_py, m)?)?;
+
+    m.add_class::<FSMInfo>()?;
+
+    m.add("BOOLEAN", json_schema::BOOLEAN)?;
+    m.add("DATE", json_schema::DATE)?;
+    m.add("DATE_TIME", json_schema::DATE_TIME)?;
+    m.add("INTEGER", json_schema::INTEGER)?;
+    m.add("NULL", json_schema::NULL)?;
+    m.add("NUMBER", json_schema::NUMBER)?;
+    m.add("STRING", json_schema::STRING)?;
+    m.add("STRING_INNER", json_schema::STRING_INNER)?;
+    m.add("TIME", json_schema::TIME)?;
+    m.add("UUID", json_schema::UUID)?;
+    m.add("WHITESPACE", json_schema::WHITESPACE)?;
+
+    m.add_function(wrap_pyfunction!(build_regex_from_schema_py, m)?)?;
+    m.add_function(wrap_pyfunction!(to_regex_py, m)?)?;
+
+    Ok(())
+}
diff --git a/src/regex.rs b/src/regex.rs
index df7d36f6..1db920ac 100644
--- a/src/regex.rs
+++ b/src/regex.rs
@@ -1,9 +1,6 @@
-use pyo3::prelude::*;
-
-use pyo3::types::PyDict;
 use std::collections::{HashMap, HashSet};
 
-pub fn walk_fsm_internal(
+pub fn walk_fsm(
     fsm_transitions: &HashMap<(u32, u32), u32>,
     _fsm_initial: u32,
     fsm_finals: &HashSet<u32>,
@@ -40,7 +37,7 @@
     accepted_states
 }
 
-pub fn state_scan_tokens_internal(
+pub fn state_scan_tokens(
     fsm_transitions: &HashMap<(u32, u32), u32>,
     fsm_initial: u32,
     fsm_finals: &HashSet<u32>,
@@ -55,7 +52,7 @@
     {
         let token_ids: Vec<u32> = vocab_item.1.clone();
 
-        let state_seq = walk_fsm_internal(
+        let state_seq = walk_fsm(
             fsm_transitions,
             fsm_initial,
             fsm_finals,
@@ -76,7 +73,7 @@
     res
 }
 
-pub fn get_token_transition_keys_internal(
+pub fn get_token_transition_keys(
     alphabet_symbol_mapping: &HashMap<String, u32>,
     alphabet_anything_value: u32,
     token_str: &str,
@@ -109,7 +106,7 @@
     token_transition_keys
 }
 
-pub fn get_vocabulary_transition_keys_internal(
+pub fn get_vocabulary_transition_keys(
     alphabet_symbol_mapping: &HashMap<String, u32>,
     alphabet_anything_value: u32,
     vocabulary: &[(String, Vec<u32>)],
@@ -132,7 +129,7 @@
                     .unwrap_or(&alphabet_anything_value),
             )
         } else {
-            token_transition_keys = get_token_transition_keys_internal(
+            token_transition_keys = get_token_transition_keys(
                 alphabet_symbol_mapping,
                 alphabet_anything_value,
                 &token_str,
@@ -144,168 +141,3 @@
 
     vocab_transition_keys
 }
-
-#[pyclass]
-pub struct FSMInfo {
-    #[pyo3(get)]
-    initial: u32,
-    #[pyo3(get)]
-    finals: HashSet<u32>,
-    #[pyo3(get)]
-    transitions: HashMap<(u32, u32), u32>,
-    #[pyo3(get)]
-    alphabet_anything_value: u32,
-    #[pyo3(get)]
-    alphabet_symbol_mapping: HashMap<String, u32>,
-}
-
-#[pymethods]
-impl FSMInfo {
-    #[new]
-    fn new(
-        initial: u32,
-        finals: HashSet<u32>,
-        transitions: HashMap<(u32, u32), u32>,
-        alphabet_anything_value: u32,
-        alphabet_symbol_mapping: HashMap<String, u32>,
-    ) -> Self {
-        Self {
-            initial,
-            finals,
-            transitions,
-            alphabet_anything_value,
-            alphabet_symbol_mapping,
-        }
-    }
-}
-
-#[pyfunction(name = "_walk_fsm")]
-#[pyo3(
-    text_signature = "(fsm_transitions, fsm_initial, fsm_finals, token_transition_keys, start_state, full_match)"
-)]
-pub fn _walk_fsm(
-    fsm_transitions: HashMap<(u32, u32), u32>,
-    fsm_initial: u32,
-    fsm_finals: HashSet<u32>,
-    token_transition_keys: Vec<u32>,
-    start_state: u32,
-    full_match: bool,
-) -> PyResult<Vec<u32>> {
-    Ok(walk_fsm_internal(
-        &fsm_transitions,
-        fsm_initial,
-        &fsm_finals,
-        &token_transition_keys,
-        start_state,
-        full_match,
-    ))
-}
-
-#[pyfunction(name = "state_scan_tokens")]
-#[pyo3(
-    text_signature = "(fsm_transitions, fsm_initial, fsm_finals, vocabulary, vocabulary_transition_keys, start_state)"
-)]
-pub fn state_scan_tokens(
-    fsm_transitions: HashMap<(u32, u32), u32>,
-    fsm_initial: u32,
-    fsm_finals: HashSet<u32>,
-    vocabulary: Vec<(String, Vec<u32>)>,
-    vocabulary_transition_keys: Vec<Vec<u32>>,
-    start_state: u32,
-) -> PyResult<HashSet<(u32, u32)>> {
-    Ok(state_scan_tokens_internal(
-        &fsm_transitions,
-        fsm_initial,
-        &fsm_finals,
-        &vocabulary,
-        &vocabulary_transition_keys,
-        start_state,
-    ))
-}
-
-#[pyfunction(name = "get_token_transition_keys")]
-#[pyo3(text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, token_str)")]
-pub fn get_token_transition_keys(
-    alphabet_symbol_mapping: HashMap<String, u32>,
-    alphabet_anything_value: u32,
-    token_str: String,
-) -> PyResult<Vec<u32>> {
-    Ok(get_token_transition_keys_internal(
-        &alphabet_symbol_mapping,
-        alphabet_anything_value,
-        &token_str,
-    ))
-}
-
-#[pyfunction(name = "get_vocabulary_transition_keys")]
-#[pyo3(
-    text_signature = "(alphabet_symbol_mapping, alphabet_anything_value, vocabulary, frozen_tokens)"
-)]
-pub fn get_vocabulary_transition_keys(
-    alphabet_symbol_mapping: HashMap<String, u32>,
-    alphabet_anything_value: u32,
-    vocabulary: Vec<(String, Vec<u32>)>,
-    frozen_tokens: HashSet<String>,
-) -> PyResult<Vec<Vec<u32>>> {
-    Ok(get_vocabulary_transition_keys_internal(
-        &alphabet_symbol_mapping,
-        alphabet_anything_value,
-        &vocabulary,
-        &frozen_tokens,
-    ))
-}
-
-#[allow(clippy::too_many_arguments)]
-#[pyfunction(name = "create_fsm_index_end_to_end")]
-#[pyo3(text_signature = "(fsm_info, vocabulary, frozen_tokens)")]
-pub fn create_fsm_index_end_to_end<'py>(
-    py: Python<'py>,
-    fsm_info: &FSMInfo,
-    vocabulary: Vec<(String, Vec<u32>)>,
-    frozen_tokens: HashSet<String>,
-) -> PyResult<Bound<'py, PyDict>> {
-    let states_to_token_subsets = PyDict::new_bound(py);
-    let mut seen: HashSet<u32> = HashSet::new();
-    let mut next_states: HashSet<u32> = HashSet::from_iter(vec![fsm_info.initial]);
-
-    let vocabulary_transition_keys = get_vocabulary_transition_keys_internal(
-        &fsm_info.alphabet_symbol_mapping,
-        fsm_info.alphabet_anything_value,
-        &vocabulary,
-        &frozen_tokens,
-    );
-
-    while let Some(start_state) = next_states.iter().cloned().next() {
-        next_states.remove(&start_state);
-
-        // TODO: Return Pydict directly at construction
-        let token_ids_end_states = state_scan_tokens_internal(
-            &fsm_info.transitions,
-            fsm_info.initial,
-            &fsm_info.finals,
-            &vocabulary,
-            &vocabulary_transition_keys,
-            start_state,
-        );
-
-        for (token_id, end_state) in token_ids_end_states {
-            if let Ok(Some(existing_dict)) = states_to_token_subsets.get_item(start_state) {
-                existing_dict.set_item(token_id, end_state).unwrap();
-            } else {
-                let new_dict = PyDict::new_bound(py);
-                new_dict.set_item(token_id, end_state).unwrap();
-                states_to_token_subsets
-                    .set_item(start_state, new_dict)
-                    .unwrap();
-            }
-
-            if !seen.contains(&end_state) {
-                next_states.insert(end_state);
-            }
-        }
-
-        seen.insert(start_state);
-    }
-
-    Ok(states_to_token_subsets)
-}
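With pyo3 now optional behind the `python-bindings` feature and the library built as an `rlib` in addition to a `cdylib`, the crate can be consumed from plain Rust without any Python toolchain. Below is a minimal, hypothetical sketch of such a consumer; it assumes the crate is added as a path dependency named `outlines_core_rs` with default features (so `python-bindings` stays off) and that `json_schema::build_regex_from_schema` keeps a `Result`-returning signature, as the bindings above rely on.

    // Hypothetical downstream crate using outlines_core_rs as a plain Rust library.
    use outlines_core_rs::json_schema;

    fn main() {
        // A small JSON Schema; the returned regex constrains text to valid instances.
        let schema = r#"{"type": "integer"}"#;

        // `None` keeps the default whitespace pattern, mirroring the Python binding.
        match json_schema::build_regex_from_schema(schema, None) {
            Ok(pattern) => println!("{pattern}"),
            Err(e) => eprintln!("failed to build regex: {e}"),
        }
    }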