Skip to content

Commit

Permalink
Add basic evaluation for BmaUpdateFn, refactor parts of code.
Browse files Browse the repository at this point in the history
  • Loading branch information
ondrej33 committed Sep 8, 2024
1 parent 35959c0 commit cefe2af
Show file tree
Hide file tree
Showing 10 changed files with 238 additions and 96 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ serde_json = "1.0"
serde_with = "3.9.0"
serde-xml-rs = "0.6.0"
num-rational = "0.4.2"
num-traits = "0.2.19"
28 changes: 26 additions & 2 deletions src/_impl_bma_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use crate::update_fn::bma_fn_tree::BmaFnUpdate;
use crate::update_fn::parser::parse_bma_formula;
use crate::xml_model::XmlBmaModel;
use biodivine_lib_param_bn::{BooleanNetwork, RegulatoryGraph};
use regex::Regex;
use std::cmp::max;
use std::collections::HashMap;

impl<'de> JsonSerDe<'de> for BmaModel {
Expand Down Expand Up @@ -227,11 +229,17 @@ impl From<XmlBmaModel> for BmaModel {
}

impl BmaModel {
fn canonical_var_name(var: &Variable) -> String {
// Regex that matches non-alphanumeric and non-underscore characters
let re = Regex::new(r"[^0-9a-zA-Z_]").unwrap();
let sanitized_name = re.replace_all(&var.name, "");
format!("v_{}_{}", var.id, sanitized_name)
}

pub fn to_regulatory_graph(&self) -> Result<RegulatoryGraph, String> {
let mut variables_map: HashMap<u32, String> = HashMap::new();
for var in &self.model.variables {
let inserted =
variables_map.insert(var.id, format!("v_{}_{}", var.id, var.name.clone()));
let inserted = variables_map.insert(var.id, BmaModel::canonical_var_name(var));
if inserted.is_some() {
return Err(format!("Variable ID {} is not unique.", var.id));
}
Expand All @@ -240,6 +248,7 @@ impl BmaModel {
let mut graph = RegulatoryGraph::new(variables);

// add regulations
// TODO: decide how to handle "doubled" regulations (of the same vs of different type)
self.model
.relationships
.iter()
Expand All @@ -260,6 +269,11 @@ impl BmaModel {
}

pub fn to_boolean_network(&self) -> Result<BooleanNetwork, String> {
// TODO: for now, we are only allowing conversion of Boolean models (not multi-valued)
if self.get_max_var_level() > 1 {
return Err("Cannot convert multi-valued model to a Boolean network.".to_string());
}

let graph = self.to_regulatory_graph()?;
let bn = BooleanNetwork::new(graph);

Expand Down Expand Up @@ -363,6 +377,16 @@ impl BmaModel {
metadata: HashMap::new(),
})
}

pub fn get_max_var_level(&self) -> u32 {
let mut max_level = 0;
self.model.variables.iter().for_each(|v| {
// just in case, lets check both `range_from` and `range_to`
max_level = max(max_level, v.range_from);
max_level = max(max_level, v.range_to);
});
max_level
}
}

#[cfg(test)]
Expand Down
17 changes: 2 additions & 15 deletions src/bin/load_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,8 @@ fn test_parse_all_models_in_dir(models_dir: &str) {

let result_model = BmaModel::from_json_str(&json_data);
match result_model {
Ok(bma_model) => {
let result_bn = bma_model.to_boolean_network();
match result_bn {
Ok(_) => {
println!("Successfully parsed and converted model: `{model_path_str}`.");
}
Err(e) => {
println!(
"Failed to convert model `{}` to BN: {:?}.",
model_path_str, e
);
}
}

println!("Successfully parsed and converted model: `{model_path_str}`.");
Ok(_) => {
println!("Successfully parsed model `{model_path_str}`.");
}
Err(e) => {
println!("Failed to parse JSON file `{}`: {:?}.", model_path_str, e);
Expand Down
17 changes: 2 additions & 15 deletions src/bin/load_xml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,8 @@ fn test_parse_all_models_in_dir(models_dir: &str) {

let result_model = BmaModel::from_xml_str(&xml_data);
match result_model {
Ok(bma_model) => {
let result_bn = bma_model.to_boolean_network();
match result_bn {
Ok(_) => {
println!("Successfully parsed and converted model: `{model_path_str}`.");
}
Err(e) => {
println!(
"Failed to convert model `{}` to BN: {:?}.",
model_path_str, e
);
}
}

println!("Successfully parsed and converted model: `{model_path_str}`.");
Ok(_) => {
println!("Successfully parsed model `{model_path_str}`.");
}
Err(e) => {
println!("Failed to parse JSON file `{}`: {:?}.", model_path_str, e);
Expand Down
16 changes: 8 additions & 8 deletions src/update_fn/_impl_from_update_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,32 +32,32 @@ impl BmaFnUpdate {
match op {
// AND: map A && B to A * B
BinaryOp::And => {
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Times)
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Mult)
}
// OR: map A || B to A + B - A * B
BinaryOp::Or => {
let sum_expr = BmaFnUpdate::mk_arithmetic(
left_expr.clone(),
right_expr.clone(),
ArithOp::Add,
ArithOp::Plus,
);
let prod_expr =
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Times);
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Mult);
BmaFnUpdate::mk_arithmetic(sum_expr, prod_expr, ArithOp::Minus)
}
// XOR: map A ^ B to A + B - 2 * (A * B)
BinaryOp::Xor => {
let sum_expr = BmaFnUpdate::mk_arithmetic(
left_expr.clone(),
right_expr.clone(),
ArithOp::Add,
ArithOp::Plus,
);
let prod_expr =
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Times);
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Mult);
let two_prod_expr = BmaFnUpdate::mk_arithmetic(
BmaFnUpdate::mk_constant(2),
prod_expr,
ArithOp::Times,
ArithOp::Mult,
);
BmaFnUpdate::mk_arithmetic(sum_expr, two_prod_expr, ArithOp::Minus)
}
Expand All @@ -79,8 +79,8 @@ impl BmaFnUpdate {
ArithOp::Minus,
);
let prod_expr =
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Times);
BmaFnUpdate::mk_arithmetic(not_left_expr, prod_expr, ArithOp::Add)
BmaFnUpdate::mk_arithmetic(left_expr, right_expr, ArithOp::Mult);
BmaFnUpdate::mk_arithmetic(not_left_expr, prod_expr, ArithOp::Plus)
}
}
}
Expand Down
145 changes: 144 additions & 1 deletion src/update_fn/_impl_to_update_fn.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,152 @@
use crate::update_fn::bma_fn_tree::BmaFnUpdate;
use crate::update_fn::bma_fn_tree::{BmaFnUpdate, Expression};
use crate::update_fn::enums::{AggregateFn, ArithOp, Literal, UnaryFn};
use biodivine_lib_param_bn::FnUpdate;
use num_rational::Rational32;
use num_traits::sign::Signed;
use std::collections::HashMap;

impl BmaFnUpdate {
pub fn to_update_fn(&self) -> FnUpdate {
// TODO: implementation via explicit construction of the function table
todo!()
}

pub fn evaluate_in_valuation(
&self,
valuation: &HashMap<String, Rational32>,
) -> Result<Rational32, String> {
match &self.expression_tree {
Expression::Terminal(Literal::Str(name)) => {
if let Some(value) = valuation.get(name) {
Ok(*value)
} else {
Err(format!("Variable `{name}` not found in the valuation."))
}
}
Expression::Terminal(Literal::Int(value)) => Ok(Rational32::new(*value, 1)),
Expression::Arithmetic(operator, left, right) => {
let left_value = left.evaluate_in_valuation(valuation)?;
let right_value = right.evaluate_in_valuation(valuation)?;
let res = match operator {
ArithOp::Plus => left_value + right_value,
ArithOp::Minus => left_value - right_value,
ArithOp::Mult => left_value * right_value,
ArithOp::Div => left_value / right_value,
};
Ok(res)
}
Expression::Unary(function, child_node) => {
let child_value = child_node.evaluate_in_valuation(valuation)?;
let res = match function {
UnaryFn::Abs => Rational32::abs(&child_value),
UnaryFn::Ceil => Rational32::ceil(&child_value),
UnaryFn::Floor => Rational32::floor(&child_value),
};
Ok(res)
}
Expression::Aggregation(function, arguments) => {
let args_values: Vec<Rational32> = arguments
.iter()
.map(|arg| arg.evaluate_in_valuation(valuation))
.collect::<Result<Vec<Rational32>, String>>()?;
let res = match function {
AggregateFn::Avg => {
let count = args_values.len() as i32;
let sum: Rational32 = args_values.iter().cloned().sum();
sum / Rational32::from_integer(count)
}
AggregateFn::Max => args_values
.iter()
.cloned()
.max()
.expect("List of numbers is empty"),
AggregateFn::Min => args_values
.iter()
.cloned()
.min()
.expect("List of numbers is empty"),
};
Ok(res)
}
}
}
}

#[cfg(test)]
mod tests {
use crate::update_fn::parser::parse_bma_formula;
use num_rational::Rational32;
use std::collections::HashMap;

#[test]
fn test_evaluate_terminal_str() {
let expression = parse_bma_formula("x").unwrap();
let valuation = HashMap::from([("x".to_string(), Rational32::new(5, 1))]);
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(5, 1));
}

#[test]
fn test_evaluate_terminal_int() {
let expression = parse_bma_formula("7").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(7, 1));
}

#[test]
fn test_evaluate_arithmetic_plus() {
let expression = parse_bma_formula("2 + 3").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(5, 1));
}

#[test]
fn test_evaluate_arithmetic_mult() {
let expression = parse_bma_formula("4 * x").unwrap();
let valuation = HashMap::from([("x".to_string(), Rational32::new(2, 1))]);
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(8, 1));
}

#[test]
fn test_evaluate_unary_abs() {
let expression = parse_bma_formula("abs(5 - 10)").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(5, 1));
}

#[test]
fn test_evaluate_aggregation_avg() {
let expression = parse_bma_formula("avg(1, 2, 3)").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(2, 1));
}

#[test]
fn test_evaluate_aggregation_max() {
let expression = parse_bma_formula("max(1, 4, 3)").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(4, 1));
}

#[test]
fn test_evaluate_aggregation_min() {
let expression = parse_bma_formula("min(1, 2 - 4, 3)").unwrap();
let valuation = HashMap::new();
let result = expression.evaluate_in_valuation(&valuation);
assert!(result.is_ok());
assert_eq!(result.unwrap(), Rational32::new(-2, 1));
}
}
14 changes: 7 additions & 7 deletions src/update_fn/bma_fn_tree.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::update_fn::enums::{AggregateOp, ArithOp, Literal, UnaryOp};
use crate::update_fn::enums::{AggregateFn, ArithOp, Literal, UnaryFn};
use crate::update_fn::parser::parse_bma_fn_tokens;
use crate::update_fn::tokenizer::BmaFnToken;
use serde::{Deserialize, Serialize};
Expand All @@ -9,15 +9,15 @@ use std::fmt;
///
/// In particular, a node type can be:
/// - A "terminal" node containing a literal (variable, constant).
/// - A "unary" node with a `UnaryOp` and a sub-expression.
/// - A "unary" node with a `UnaryFn` and a sub-expression.
/// - A binary "arithmetic" node, with a `BinaryOp` and two sub-expressions.
/// - An "aggregation" node with a `AggregateOp` op and a list of sub-expressions.
/// - An "aggregation" node with a `AggregateFn` op and a list of sub-expressions.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
pub enum Expression {
Terminal(Literal),
Unary(UnaryOp, Box<BmaFnUpdate>),
Unary(UnaryFn, Box<BmaFnUpdate>),
Arithmetic(ArithOp, Box<BmaFnUpdate>, Box<BmaFnUpdate>),
Aggregation(AggregateOp, Vec<Box<BmaFnUpdate>>),
Aggregation(AggregateFn, Vec<Box<BmaFnUpdate>>),
}

/// A single node in a syntax tree of a FOL formula.
Expand All @@ -42,7 +42,7 @@ impl BmaFnUpdate {
/// Create a "unary" [BmaFnUpdate] from the given arguments.
///
/// See also [Expression::Unary].
pub fn mk_unary(child: BmaFnUpdate, op: UnaryOp) -> BmaFnUpdate {
pub fn mk_unary(child: BmaFnUpdate, op: UnaryFn) -> BmaFnUpdate {
let subform_str = format!("{op}({child})");
BmaFnUpdate {
function_str: subform_str,
Expand Down Expand Up @@ -86,7 +86,7 @@ impl BmaFnUpdate {
}

/// Create a [BmaFnUpdate] representing an aggregation operator applied to given arguments.
pub fn mk_aggregation(op: AggregateOp, inner_nodes: Vec<BmaFnUpdate>) -> BmaFnUpdate {
pub fn mk_aggregation(op: AggregateFn, inner_nodes: Vec<BmaFnUpdate>) -> BmaFnUpdate {
let max_height = inner_nodes
.iter()
.map(|node| node.height)
Expand Down
Loading

0 comments on commit cefe2af

Please sign in to comment.