diff --git a/src/_impl_bma_model.rs b/src/_impl_bma_model.rs index 49aaf96..c31c5f4 100644 --- a/src/_impl_bma_model.rs +++ b/src/_impl_bma_model.rs @@ -2,6 +2,8 @@ use crate::bma_model::*; use crate::enums::VariableType; use crate::json_model::JsonBmaModel; use crate::traits::{JsonSerde, XmlSerde}; +use crate::update_fn::bma_fn_tree::BmaFnNode; +use crate::update_fn::parser::parse_bma_formula; use crate::xml_model::XmlBmaModel; use biodivine_lib_param_bn::{BooleanNetwork, RegulatoryGraph}; use std::collections::HashMap; @@ -74,7 +76,8 @@ impl From for BmaModel { .unwrap_or(VariableType::Default), // Use the type from layout if available range_from: var.range_from, range_to: var.range_to, - formula: var.formula, + // todo: handle the failures and empty formulas + formula: parse_bma_formula(&var.formula).unwrap_or(BmaFnNode::mk_constant(0)), }) .collect(), relationships: json_model @@ -157,7 +160,8 @@ impl From for BmaModel { variable_type: var.r#type, range_from: var.range_from, range_to: var.range_to, - formula: var.formula, + // todo: handle the failures and empty formulas + formula: parse_bma_formula(&var.formula).unwrap_or(BmaFnNode::mk_constant(0)), }) .collect(), relationships: xml_model diff --git a/src/bma_model.rs b/src/bma_model.rs index 7620bdc..03b89e5 100644 --- a/src/bma_model.rs +++ b/src/bma_model.rs @@ -1,4 +1,5 @@ use crate::enums::{RelationshipType, VariableType}; +use crate::update_fn::bma_fn_tree::BmaFnNode; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -26,7 +27,7 @@ pub struct Variable { pub variable_type: VariableType, // Corresponds to "Type" in JSON/XML pub range_from: u32, pub range_to: u32, - pub formula: String, + pub formula: BmaFnNode, } #[derive(Serialize, Deserialize, Debug, Clone)] diff --git a/src/update_fn/bma_fn_tree.rs b/src/update_fn/bma_fn_tree.rs index 8ebeddc..afebb32 100644 --- a/src/update_fn/bma_fn_tree.rs +++ b/src/update_fn/bma_fn_tree.rs @@ -1,5 +1,7 @@ use crate::update_fn::enums::{AggregateOp, ArithOp, Literal, UnaryOp}; +use crate::update_fn::parser::parse_bma_fn_tokens; use crate::update_fn::tokenizer::BmaFnToken; +use serde::{Deserialize, Serialize}; use std::cmp; use std::fmt; @@ -10,7 +12,7 @@ use std::fmt; /// - A "unary" node with a `UnaryOp` and a sub-expression. /// - A binary "arithmetic" node, with a `BinaryOp` and two sub-expressions. /// - An "aggregation" node with a `AggregateOp` op and a list of sub-expressions. -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub enum Expression { Terminal(Literal), Unary(UnaryOp, Box), @@ -24,7 +26,7 @@ pub enum Expression { /// - `height`; A positive integer starting from 0 (for term nodes). /// - `expression_tree`; A parse tree for the expression`. /// - `function_str`; A canonical string representation of the expression. -#[derive(Clone, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)] pub struct BmaFnNode { pub function_str: String, pub height: u32, @@ -33,8 +35,8 @@ pub struct BmaFnNode { impl BmaFnNode { /// "Parse" a new [BmaFnNode] from a list of [BmaFnToken] objects. - pub fn from_tokens(_tokens: &[BmaFnToken]) -> Result { - todo!() + pub fn from_tokens(tokens: &[BmaFnToken]) -> Result { + parse_bma_fn_tokens(tokens) } /// Create a "unary" [BmaFnNode] from the given arguments. diff --git a/src/update_fn/mod.rs b/src/update_fn/mod.rs index 185e2ea..e79a884 100644 --- a/src/update_fn/mod.rs +++ b/src/update_fn/mod.rs @@ -1,4 +1,4 @@ pub mod bma_fn_tree; -mod enums; -mod parser; -mod tokenizer; +pub mod enums; +pub mod parser; +pub mod tokenizer; diff --git a/src/update_fn/parser.rs b/src/update_fn/parser.rs index 65b3dba..118a3c5 100644 --- a/src/update_fn/parser.rs +++ b/src/update_fn/parser.rs @@ -1 +1,268 @@ -// todo +use crate::update_fn::bma_fn_tree::*; +use crate::update_fn::enums::*; +use crate::update_fn::tokenizer::{try_tokenize_bma_formula, BmaFnToken}; + +/// Parse an BMA update function formula string representation into an actual expression tree. +/// Basically a wrapper for tokenize+parse (used often for testing/debug purposes). +/// +/// NEEDS to call [validate_props] to fully finish the preprocessing step. +pub fn parse_bma_formula(formula: &str) -> Result { + let tokens = try_tokenize_bma_formula(formula.to_string())?; + let tree = parse_bma_fn_tokens(&tokens)?; + Ok(tree) +} + +/// Utility method to find the first occurrence of a specific token in the token tree. +fn index_of_first(tokens: &[BmaFnToken], token: BmaFnToken) -> Option { + return tokens.iter().position(|t| *t == token); +} + +/// Parse `tokens` of BMA update fn formula into an abstract syntax tree using recursive steps. +pub fn parse_bma_fn_tokens(tokens: &[BmaFnToken]) -> Result { + parse_1_div(tokens) +} + +/// Recursive parsing step 1: extract `/` operators. +fn parse_1_div(tokens: &[BmaFnToken]) -> Result { + let div_token = index_of_first(tokens, BmaFnToken::Binary(ArithOp::Div)); + Ok(if let Some(i) = div_token { + BmaFnNode::mk_arithmetic( + parse_2_mul(&tokens[..i])?, + parse_1_div(&tokens[(i + 1)..])?, + ArithOp::Div, + ) + } else { + parse_2_mul(tokens)? + }) +} + +/// Recursive parsing step 2: extract `*` operators. +fn parse_2_mul(tokens: &[BmaFnToken]) -> Result { + let mul_token = index_of_first(tokens, BmaFnToken::Binary(ArithOp::Times)); + Ok(if let Some(i) = mul_token { + BmaFnNode::mk_arithmetic( + parse_3_minus(&tokens[..i])?, + parse_2_mul(&tokens[(i + 1)..])?, + ArithOp::Times, + ) + } else { + parse_3_minus(tokens)? + }) +} + +/// Recursive parsing step 3: extract `-` operators. +fn parse_3_minus(tokens: &[BmaFnToken]) -> Result { + let minus_token = index_of_first(tokens, BmaFnToken::Binary(ArithOp::Minus)); + Ok(if let Some(i) = minus_token { + BmaFnNode::mk_arithmetic( + parse_4_plus(&tokens[..i])?, + parse_3_minus(&tokens[(i + 1)..])?, + ArithOp::Minus, + ) + } else { + parse_4_plus(tokens)? + }) +} + +/// Recursive parsing step 4: extract `+` operators. +fn parse_4_plus(tokens: &[BmaFnToken]) -> Result { + let minus_token = index_of_first(tokens, BmaFnToken::Binary(ArithOp::Add)); + Ok(if let Some(i) = minus_token { + BmaFnNode::mk_arithmetic( + parse_5_others(&tokens[..i])?, + parse_4_plus(&tokens[(i + 1)..])?, + ArithOp::Add, + ) + } else { + parse_5_others(tokens)? + }) +} + +/// Recursive parsing step 5: extract literals and recursively solve sub-formulae in parentheses +/// and in functions. +fn parse_5_others(tokens: &[BmaFnToken]) -> Result { + if tokens.is_empty() { + Err("Expected formula, found nothing.".to_string()) + } else { + if tokens.len() == 1 { + // This should be name (var/function) or a parenthesis group, anything + // else does not make sense. + match &tokens[0] { + BmaFnToken::Atomic(Literal::Str(name)) => { + return Ok(BmaFnNode::mk_variable(name.as_str())); + } + BmaFnToken::Atomic(Literal::Int(num)) => { + return Ok(BmaFnNode::mk_constant(*num)); + } + BmaFnToken::Aggregate(operator, arguments) => { + let mut arg_expression_nodes = Vec::new(); + for inner in arguments { + // it must be a token list + if let BmaFnToken::TokenList(inner_token_list) = inner { + arg_expression_nodes.push(parse_bma_fn_tokens(inner_token_list)?); + } else { + return Err( + "Function must be applied on `BmaFnToken::TokenList` args." + .to_string(), + ); + } + } + return Ok(BmaFnNode::mk_aggregation( + operator.clone(), + arg_expression_nodes, + )); + } + BmaFnToken::Unary(operator, argument) => { + return if let BmaFnToken::TokenList(inner_token_list) = *argument.clone() { + Ok(BmaFnNode::mk_unary( + parse_bma_fn_tokens(&inner_token_list)?, + operator.clone(), + )) + } else { + return Err( + "Function must be applied on `BmaFnToken::TokenList` args.".to_string() + ); + } + } + // recursively solve sub-formulae in parentheses + BmaFnToken::TokenList(inner) => { + return parse_bma_fn_tokens(inner); + } + _ => {} // otherwise, fall through to the error at the end. + } + } + Err(format!("Unexpected: {tokens:?}. Expecting formula.")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::update_fn::bma_fn_tree::BmaFnNode; + use crate::update_fn::enums::{AggregateOp, ArithOp, UnaryOp}; + + #[test] + fn test_parse_simple_addition() { + let input = "3 + 5"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(3), + BmaFnNode::mk_constant(5), + ArithOp::Add, + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_simple_subtraction() { + let input = "10 - 7"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(10), + BmaFnNode::mk_constant(7), + ArithOp::Minus, + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_multiplication_and_division() { + let input = "8 * 4 / 2"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_arithmetic( + BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(8), + BmaFnNode::mk_constant(4), + ArithOp::Times, + ), + BmaFnNode::mk_constant(2), + ArithOp::Div, + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_nested_arithmetic() { + let input = "3 + (5 * 2)"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(3), + BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(5), + BmaFnNode::mk_constant(2), + ArithOp::Times, + ), + ArithOp::Add, + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_abs_function() { + let input = "abs(5)"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_unary(BmaFnNode::mk_constant(5), UnaryOp::Abs); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_aggregate_min() { + let input = "min(3, 5, 5 + variable)"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_aggregation( + AggregateOp::Min, + vec![ + BmaFnNode::mk_constant(3), + BmaFnNode::mk_constant(5), + BmaFnNode::mk_arithmetic( + BmaFnNode::mk_constant(5), + BmaFnNode::mk_variable("variable"), + ArithOp::Add, + ), + ], + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_unmatched_parentheses() { + let input = "3 + (5 * 2"; + let result = parse_bma_formula(input); + assert!(result.is_err()); + assert_eq!( + result, + Err("Expected ')' to previously encountered opening counterpart.".to_string()) + ); + } + + #[test] + fn test_parse_invalid_token() { + let input = "5 + @"; + let result = parse_bma_formula(input); + assert!(result.is_err()); + assert_eq!(result, Err("Unexpected character: '@'".to_string())); + } + + #[test] + fn test_parse_function_with_multiple_arguments() { + let input = "max(3, 5, 10)"; + let result = parse_bma_formula(input); + let expected = BmaFnNode::mk_aggregation( + AggregateOp::Max, + vec![ + BmaFnNode::mk_constant(3), + BmaFnNode::mk_constant(5), + BmaFnNode::mk_constant(10), + ], + ); + assert_eq!(result, Ok(expected)); + } + + #[test] + fn test_parse_empty_formula() { + let input = ""; + let result = parse_bma_formula(input); + assert!(result.is_err()); + assert_eq!(result, Err("Expected formula, found nothing.".to_string())); + } +} diff --git a/src/update_fn/tokenizer.rs b/src/update_fn/tokenizer.rs index 6ca52b5..c084a90 100644 --- a/src/update_fn/tokenizer.rs +++ b/src/update_fn/tokenizer.rs @@ -1,13 +1,324 @@ use crate::update_fn::enums::{AggregateOp, ArithOp, Literal, UnaryOp}; +use std::iter::Peekable; +use std::str::Chars; /// Enum of all possible tokens occurring in a BMA function string. #[derive(Clone, Debug, Eq, Hash, PartialEq)] pub enum BmaFnToken { Atomic(Literal), - Unary(UnaryOp), + Unary(UnaryOp, Box), Binary(ArithOp), Aggregate(AggregateOp, Vec), TokenList(Vec), } -// todo +/// Tokenize a BMA formula string into tokens, +/// +/// This is a wrapper for the (more general) recursive [try_tokenize_recursive]` function. +pub fn try_tokenize_bma_formula(formula: String) -> Result, String> { + let (tokens, _) = try_tokenize_recursive(&mut formula.chars().peekable(), true, false)?; + Ok(tokens) +} + +/// Process a peekable iterator of characters into a vector of `BmaFnToken`s. This function is used +/// for both tokenizing a top-level expression and expressions that are fn's arguments. +/// +/// Returns a vector of (nested) tokens, and a last character. The last character is important when +/// we are parsing function arguments (to find out if another argument is expected or we already +/// processed the closing parenthesis). When parsing the top-level formula expression (not a function +/// argument), we simply return '$'. +/// +/// `top_fn_level` is used in case we are processing an expression passed as argument to some +/// function symbol (then ',' is valid delimiter). +fn try_tokenize_recursive( + input_chars: &mut Peekable, + top_level: bool, + top_fn_level: bool, +) -> Result<(Vec, char), String> { + let mut output = Vec::new(); + + while let Some(c) = input_chars.next() { + match c { + c if c.is_whitespace() => {} + '+' => output.push(BmaFnToken::Binary(ArithOp::Add)), + '-' => output.push(BmaFnToken::Binary(ArithOp::Minus)), + '*' => output.push(BmaFnToken::Binary(ArithOp::Times)), + '/' => output.push(BmaFnToken::Binary(ArithOp::Div)), + '(' => { + // start a nested token group + let (token_group, _) = try_tokenize_recursive(input_chars, false, false)?; + output.push(BmaFnToken::TokenList(token_group)); + } + ')' => { + return if !top_level { + Ok((output, ')')) + } else { + Err("Unexpected ')' without opening counterpart.".to_string()) + } + } + ',' if top_fn_level => { + // in case we are collecting something inside a function, a comma is valid delimiter + return Ok((output, ',')); + } + // parse literals, function names + c if is_valid_start_name(c) => { + let name = format!("{c}{}", collect_name(input_chars)); + match name.as_str() { + "abs" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Unary( + UnaryOp::Abs, + Box::new(args[0].to_owned()), + )) + } + "ceil" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Unary( + UnaryOp::Ceil, + Box::new(args[0].to_owned()), + )) + } + "floor" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Unary( + UnaryOp::Floor, + Box::new(args[0].to_owned()), + )) + } + "min" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Aggregate(AggregateOp::Min, args)); + } + "max" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Aggregate(AggregateOp::Max, args)); + } + "avg" => { + let args = collect_fn_arguments(input_chars)?; + output.push(BmaFnToken::Aggregate(AggregateOp::Avg, args)); + } + _ => { + // Assume it’s a literal + output.push(BmaFnToken::Atomic(Literal::Str(name))); + } + } + } + '0'..='9' => { + let number = format!("{c}{}", collect_number_str(input_chars)); + let int_number = number + .parse::() + .map_err(|_| "Failed to parse number".to_string())?; + output.push(BmaFnToken::Atomic(Literal::Int(int_number))); + } + _ => { + return Err(format!("Unexpected character: '{c}'")); + } + } + } + + if top_level { + Ok((output, '$')) + } else { + Err("Expected ')' to previously encountered opening counterpart.".to_string()) + } +} + +/// Check all whitespaces at the front of the iterator. +fn skip_whitespaces(chars: &mut Peekable) { + while let Some(&c) = chars.peek() { + if c.is_whitespace() { + chars.next(); // Skip the whitespace character + } else { + break; // Stop skipping when a non-whitespace character is found + } + } +} + +/// Check if given char can appear in a name. +fn is_valid_in_name(c: char) -> bool { + c.is_alphanumeric() || c == '_' +} + +/// Check if given char can appear at the beginning of a name. +fn is_valid_start_name(c: char) -> bool { + c.is_alphabetic() || c == '_' +} + +/// Collects a name (e.g., for variables, functions) from the input character iterator. +fn collect_name(input_chars: &mut Peekable) -> String { + let mut name = String::new(); + while let Some(&c) = input_chars.peek() { + if is_valid_in_name(c) { + name.push(c); + input_chars.next(); // consume the character + } else { + break; + } + } + name +} + +/// Collects a number (integer) from the input character iterator. +fn collect_number_str(input_chars: &mut Peekable) -> String { + let mut number_str = String::new(); + while let Some(&c) = input_chars.peek() { + if c.is_ascii_digit() { + number_str.push(c); + input_chars.next(); // consume the character + } else { + break; + } + } + number_str +} + +/// Collects the arguments for a function from the input character iterator. +fn collect_fn_arguments(input_chars: &mut Peekable) -> Result, String> { + skip_whitespaces(input_chars); + + if Some('(') != input_chars.next() { + return Err("Function name must be followed by `(`.".to_string()); + } + + let mut args = Vec::new(); + let mut last_delim = ','; + + while last_delim != ')' { + assert_eq!(last_delim, ','); + let (token_group, last_char) = try_tokenize_recursive(input_chars, false, true)?; + if token_group.is_empty() { + return Err("Function argument cannot be empty.".to_string()); + } + args.push(BmaFnToken::TokenList(token_group)); + last_delim = last_char; + } + + Ok(args) +} + +#[cfg(test)] +mod tests { + use crate::update_fn::enums::{AggregateOp, ArithOp, Literal, UnaryOp}; + use crate::update_fn::tokenizer::{try_tokenize_bma_formula, BmaFnToken}; + + #[test] + fn test_simple_arithmetic() { + let input = "3 + 5 - 2".to_string(); + let result = try_tokenize_bma_formula(input); + assert_eq!( + result, + Ok(vec![ + BmaFnToken::Atomic(Literal::Int(3)), + BmaFnToken::Binary(ArithOp::Add), + BmaFnToken::Atomic(Literal::Int(5)), + BmaFnToken::Binary(ArithOp::Minus), + BmaFnToken::Atomic(Literal::Int(2)) + ]) + ); + } + + #[test] + fn test_function_with_single_argument() { + let input = "abs(5)".to_string(); + let result = try_tokenize_bma_formula(input); + assert_eq!( + result, + Ok(vec![BmaFnToken::Unary( + UnaryOp::Abs, + Box::new(BmaFnToken::TokenList(vec![BmaFnToken::Atomic( + Literal::Int(5) + )])), + )]) + ); + } + + #[test] + fn test_aggregate_function_with_multiple_arguments() { + let input = "min(5, 3)".to_string(); + let result = try_tokenize_bma_formula(input); + assert_eq!( + result, + Ok(vec![BmaFnToken::Aggregate( + AggregateOp::Min, + vec![ + BmaFnToken::TokenList(vec![BmaFnToken::Atomic(Literal::Int(5))]), + BmaFnToken::TokenList(vec![BmaFnToken::Atomic(Literal::Int(3))]) + ] + )]) + ); + } + + #[test] + fn test_nested_function_calls() { + let input = "max(abs(5), ceil(3))".to_string(); + let result = try_tokenize_bma_formula(input); + assert_eq!( + result, + Ok(vec![BmaFnToken::Aggregate( + AggregateOp::Max, + vec![ + BmaFnToken::TokenList(vec![BmaFnToken::Unary( + UnaryOp::Abs, + Box::new(BmaFnToken::TokenList(vec![BmaFnToken::Atomic( + Literal::Int(5) + )])), + )]), + BmaFnToken::TokenList(vec![BmaFnToken::Unary( + UnaryOp::Ceil, + Box::new(BmaFnToken::TokenList(vec![BmaFnToken::Atomic( + Literal::Int(3) + )])), + )]) + ] + )]) + ); + } + + #[test] + fn test_unmatched_parentheses() { + let input = "min(5, 3".to_string(); + let result = try_tokenize_bma_formula(input); + assert!(result.is_err()); + assert_eq!( + result, + Err("Expected ')' to previously encountered opening counterpart.".to_string()) + ); + } + + #[test] + fn test_unexpected_character() { + let input = "5 + @".to_string(); + let result = try_tokenize_bma_formula(input); + assert!(result.is_err()); + assert_eq!(result, Err("Unexpected character: '@'".to_string())); + } + + #[test] + fn test_compound_expression_with_nested_parentheses() { + let input = "3 + (5 * (2 + 1))".to_string(); + let result = try_tokenize_bma_formula(input); + assert_eq!( + result, + Ok(vec![ + BmaFnToken::Atomic(Literal::Int(3)), + BmaFnToken::Binary(ArithOp::Add), + BmaFnToken::TokenList(vec![ + BmaFnToken::Atomic(Literal::Int(5)), + BmaFnToken::Binary(ArithOp::Times), + BmaFnToken::TokenList(vec![ + BmaFnToken::Atomic(Literal::Int(2)), + BmaFnToken::Binary(ArithOp::Add), + BmaFnToken::Atomic(Literal::Int(1)) + ]) + ]) + ]) + ); + } + + #[test] + fn test_function_with_no_arguments_invalid() { + let input = "abs()".to_string(); + let result = try_tokenize_bma_formula(input); + assert!(result.is_err()); + } +}