From a8e43a494dc5a499e8f3bcc295768dd8ec8d771d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Chud=C3=AD=C4=8Dek?= Date: Thu, 14 Dec 2023 14:43:54 +0100 Subject: [PATCH] refactor: rewritten the reachability benchmark --- examples/reachability.rs | 4 + examples/rewritten_reachability.rs | 27 ++++ src/benchmarks/mod.rs | 2 + src/benchmarks/rewritten_reachability.rs | 165 +++++++++++++++++++++++ src/benchmarks/utils.rs | 67 +++++++++ src/prototype/utils.rs | 1 + src/symbolic_domains/symbolic_domain.rs | 3 + src/update/update_fn.rs | 9 +- tests/some_test.rs | 2 +- 9 files changed, 275 insertions(+), 5 deletions(-) create mode 100644 examples/rewritten_reachability.rs create mode 100644 src/benchmarks/rewritten_reachability.rs create mode 100644 src/benchmarks/utils.rs diff --git a/examples/reachability.rs b/examples/reachability.rs index 3d3a7b6..a63c81e 100644 --- a/examples/reachability.rs +++ b/examples/reachability.rs @@ -8,6 +8,8 @@ fn main() { let representation = args[1].clone(); let sbml_path = args[2].clone(); + let now = std::time::Instant::now(); + match representation.as_str() { "unary" => reachability_benchmark::(sbml_path.as_str()), "binary" => reachability_benchmark::>(sbml_path.as_str()), @@ -15,4 +17,6 @@ fn main() { "gray" | "grey" => reachability_benchmark::>(sbml_path.as_str()), _ => panic!("Unknown representation: {}.", representation), } + + println!("Time: {}s", now.elapsed().as_secs()); } diff --git a/examples/rewritten_reachability.rs b/examples/rewritten_reachability.rs new file mode 100644 index 0000000..3d66718 --- /dev/null +++ b/examples/rewritten_reachability.rs @@ -0,0 +1,27 @@ +use biodivine_lib_logical_models::{ + benchmarks::rewritten_reachability::reachability_benchmark, + prelude::symbolic_domain::{ + BinaryIntegerDomain, GrayCodeIntegerDomain, PetriNetIntegerDomain, UnaryIntegerDomain, + }, +}; +// use biodivine_lib_logical_models::prelude::old_symbolic_domain::{ +// BinaryIntegerDomain, GrayCodeIntegerDomain, PetriNetIntegerDomain, UnaryIntegerDomain, +// }; + +fn main() { + let args = std::env::args().collect::>(); + let representation = args[1].clone(); + let sbml_path = args[2].clone(); + + let now = std::time::Instant::now(); + + match representation.as_str() { + "unary" => reachability_benchmark::(sbml_path.as_str()), + "binary" => reachability_benchmark::>(sbml_path.as_str()), + "petri_net" => reachability_benchmark::(sbml_path.as_str()), + "gray" | "grey" => reachability_benchmark::>(sbml_path.as_str()), + _ => panic!("Unknown representation: {}.", representation), + } + + println!("Time: {}s", now.elapsed().as_secs()); +} diff --git a/src/benchmarks/mod.rs b/src/benchmarks/mod.rs index 4cd40ea..f699ead 100644 --- a/src/benchmarks/mod.rs +++ b/src/benchmarks/mod.rs @@ -1 +1,3 @@ pub mod reachability; +pub mod rewritten_reachability; +mod utils; diff --git a/src/benchmarks/rewritten_reachability.rs b/src/benchmarks/rewritten_reachability.rs new file mode 100644 index 0000000..c8edf63 --- /dev/null +++ b/src/benchmarks/rewritten_reachability.rs @@ -0,0 +1,165 @@ +use biodivine_lib_bdd::Bdd; +use std::fmt::Debug; + +use crate::{ + benchmarks::utils::{count_states, log_percent, pick_state_bdd, unit_vertex_set}, + prelude::find_start_of, + symbolic_domains::symbolic_domain::SymbolicDomainOrd, + update::update_fn::SmartSystemUpdateFn as RewrittenSmartSystemUpdateFn, +}; + +pub fn reachability_benchmark + Debug>(sbml_path: &str) { + let smart_system_update_fn = { + let mut xml = xml::reader::EventReader::new(std::io::BufReader::new( + std::fs::File::open(sbml_path).expect("should be able to open file"), + )); + + find_start_of(&mut xml, "listOfTransitions") + .expect("Cannot find transitions in the SBML file."); + + RewrittenSmartSystemUpdateFn::::try_from_xml(&mut xml) + .expect("Loading system fn update failed.") + }; + + let unit = unit_vertex_set(&smart_system_update_fn); + let system_var_count = smart_system_update_fn + .variables_transition_relation_and_domain + .len(); + println!( + "Variables: {}, expected states {}", + system_var_count, + 1 << system_var_count + ); + println!( + "Computed state count: {}", + count_states(&smart_system_update_fn, &unit) + ); + let mut universe = unit.clone(); + while !universe.is_false() { + let mut weak_scc = pick_state_bdd(&smart_system_update_fn, &universe); + loop { + let bwd_reachable = reach_bwd(&smart_system_update_fn, &weak_scc, &universe); + let fwd_bwd_reachable = reach_fwd(&smart_system_update_fn, &bwd_reachable, &universe); + + // FWD/BWD reachable set is not a subset of weak SCC, meaning the SCC can be expanded. + if !fwd_bwd_reachable.imp(&weak_scc).is_true() { + println!( + " + SCC increased to (states={}, size={})", + count_states(&smart_system_update_fn, &weak_scc), + weak_scc.size() + ); + weak_scc = fwd_bwd_reachable; + } else { + break; + } + } + println!( + " + Found weak SCC (states={}, size={})", + count_states(&smart_system_update_fn, &weak_scc), + weak_scc.size() + ); + // Remove the SCC from the universe set and start over. + universe = universe.and_not(&weak_scc); + println!( + " + Remaining states: {}/{}", + count_states(&smart_system_update_fn, &universe), + count_states(&smart_system_update_fn, &unit), + ); + } +} + +/// Compute the set of vertices that are forward-reachable from the `initial` set. +/// +/// The result BDD contains a vertex `x` if and only if there is a (possibly zero-length) path +/// from some vertex `x' \in initial` into `x`, i.e. `x' -> x`. +pub fn reach_fwd + Debug>( + system: &RewrittenSmartSystemUpdateFn, + initial: &Bdd, + universe: &Bdd, +) -> Bdd { + // The list of system variables, sorted in descending order (i.e. opposite order compared + // to the ordering inside BDDs). + let sorted_variables = system + .variables_transition_relation_and_domain + .iter() + .map(|(var_name, _)| var_name) + .collect::>(); + let mut result = initial.clone(); + println!( + "Start forward reachability: (states={}, size={})", + count_states(system, &result), + result.size() + ); + 'fwd: loop { + for var in sorted_variables.iter().rev() { + let successors = system.successors_async(var.as_str(), &result); + + // Should be equivalent to "successors \not\subseteq result". + if !successors.imp(&result).is_true() { + result = result.or(&successors); + println!( + " >> (progress={:.2}%%, states={}, size={})", + log_percent(&result, universe), + count_states(system, &result), + result.size() + ); + continue 'fwd; + } + } + + // No further successors were computed across all variables. We are done. + println!( + " >> Done. (states={}, size={})", + count_states(system, &result), + result.size() + ); + return result; + } +} + +/// Compute the set of vertices that are backward-reachable from the `initial` set. +/// +/// The result BDD contains a vertex `x` if and only if there is a (possibly zero-length) path +/// from `x` into some vertex `x' \in initial`, i.e. `x -> x'`. +pub fn reach_bwd + Debug>( + system: &RewrittenSmartSystemUpdateFn, + initial: &Bdd, + universe: &Bdd, +) -> Bdd { + let sorted_variables = system + .variables_transition_relation_and_domain + .iter() + .map(|(var_name, _)| var_name) + .collect::>(); + let mut result = initial.clone(); + println!( + "Start backward reachability: (states={}, size={})", + count_states(system, &result), + result.size() + ); + 'bwd: loop { + for var in sorted_variables.iter().rev() { + let predecessors = system.predecessors_async(var.as_str(), result.clone()); + + // Should be equivalent to "predecessors \not\subseteq result". + if !predecessors.imp(&result).is_true() { + result = result.or(&predecessors); + println!( + " >> (progress={:.2}%%, states={}, size={})", + log_percent(&result, universe), + count_states(system, &result), + result.size() + ); + continue 'bwd; + } + } + + // No further predecessors were computed across all variables. We are done. + println!( + " >> Done. (states={}, size={})", + count_states(system, &result), + result.size() + ); + return result; + } +} diff --git a/src/benchmarks/utils.rs b/src/benchmarks/utils.rs new file mode 100644 index 0000000..44d4a90 --- /dev/null +++ b/src/benchmarks/utils.rs @@ -0,0 +1,67 @@ +use std::fmt::Debug; + +use biodivine_lib_bdd::{Bdd, BddPartialValuation}; + +use crate::{ + symbolic_domains::symbolic_domain::SymbolicDomainOrd, + update::update_fn::SmartSystemUpdateFn as RewrittenSmartSystemUpdateFn, +}; + +pub fn states>( + system: &RewrittenSmartSystemUpdateFn, + set: &Bdd, +) -> f64 { + let symbolic_var_count = system.bdd_variable_set.num_vars() as i32; + // TODO: + // Here we assume that exactly half of the variables are primed, which may not be true + // in the future, but should be good enough for now. + assert_eq!(symbolic_var_count % 2, 0); + let primed_vars = symbolic_var_count / 2; + set.cardinality() / 2.0f64.powi(primed_vars) +} + +pub fn unit_vertex_set>( + system: &RewrittenSmartSystemUpdateFn, +) -> Bdd { + system + .variables_transition_relation_and_domain + .iter() + .fold(system.bdd_variable_set.mk_true(), |acc, (_, var_info)| { + acc.and(&var_info.domain.unit_collection(&system.bdd_variable_set)) + }) +} + +/// Compute an (approximate) count of state in the given `set` using the encoding of `system`. +pub fn count_states + Debug>( + system: &RewrittenSmartSystemUpdateFn, + set: &Bdd, +) -> f64 { + let symbolic_var_count = system.variables_transition_relation_and_domain.len() as i32; + set.cardinality() / 2.0f64.powi(symbolic_var_count) +} + +/// Compute a [Bdd] which represents a single (un-primed) state within the given symbolic `set`. +pub fn pick_state_bdd + Debug>( + system: &RewrittenSmartSystemUpdateFn, + set: &Bdd, +) -> Bdd { + // Unfortunately, this is now a bit more complicated than it needs to be, because + // we have to ignore the primed variables, but it shouldn't bottleneck anything outside of + // truly extreme cases. + let standard_variables = system + .variables_transition_relation_and_domain + .iter() + .flat_map(|transition| transition.1.domain.raw_bdd_variables()); + let valuation = set + .sat_witness() + .expect("Cannot pick state from an empty set."); + let mut state_data = BddPartialValuation::empty(); + for var in standard_variables { + state_data.set_value(var, valuation.value(var)) + } + system.bdd_variable_set.mk_conjunctive_clause(&state_data) +} + +pub fn log_percent(set: &Bdd, universe: &Bdd) -> f64 { + set.cardinality().log2() / universe.cardinality().log2() * 100.0 +} diff --git a/src/prototype/utils.rs b/src/prototype/utils.rs index 65aa3af..0517ffb 100644 --- a/src/prototype/utils.rs +++ b/src/prototype/utils.rs @@ -11,6 +11,7 @@ use xml::{ }; use crate::prototype::symbolic_domain::SymbolicDomain; +use crate::symbolic_domains::symbolic_domain::SymbolicDomain as RewrittenSymbolicDomain; use super::{SmartSystemUpdateFn, UpdateFn}; diff --git a/src/symbolic_domains/symbolic_domain.rs b/src/symbolic_domains/symbolic_domain.rs index 33ccfd2..a57a7a8 100644 --- a/src/symbolic_domains/symbolic_domain.rs +++ b/src/symbolic_domains/symbolic_domain.rs @@ -267,6 +267,7 @@ impl SymbolicDomainOrd for UnaryIntegerDomain { } } +#[derive(Debug)] pub struct PetriNetIntegerDomain { /// invariant: sorted variables: Vec, @@ -355,6 +356,7 @@ impl SymbolicDomainOrd for PetriNetIntegerDomain { } } +#[derive(Debug)] pub struct BinaryIntegerDomain { /// invariant: sorted variables: Vec, @@ -460,6 +462,7 @@ impl SymbolicDomainOrd for BinaryIntegerDomain { } } +#[derive(Debug)] pub struct GrayCodeIntegerDomain { /// invariant: sorted variables: Vec, diff --git a/src/update/update_fn.rs b/src/update/update_fn.rs index 317b203..34c18c4 100644 --- a/src/update/update_fn.rs +++ b/src/update/update_fn.rs @@ -326,12 +326,13 @@ where } } -struct VarInfo +pub struct VarInfo +// todo do not keep pub; just for benchmarks where D: SymbolicDomain, { primed_name: String, - domain: D, + pub domain: D, // todo do not keep pub; just for benchmarks primed_domain: D, transition_relation: Bdd, _marker: std::marker::PhantomData, @@ -342,8 +343,8 @@ where D: SymbolicDomain, { /// ordered by variable name // todo add a method to get the update function by name (hash map or binary search) - variables_transition_relation_and_domain: Vec<(String, VarInfo)>, - bdd_variable_set: BddVariableSet, + pub variables_transition_relation_and_domain: Vec<(String, VarInfo)>, // todo do not keep pub; just here for benchmarking + pub bdd_variable_set: BddVariableSet, // todo do not keep pub; just here for testing _marker: std::marker::PhantomData, } diff --git a/tests/some_test.rs b/tests/some_test.rs index 6d41ea9..9300a22 100644 --- a/tests/some_test.rs +++ b/tests/some_test.rs @@ -745,7 +745,7 @@ fn predecessors_consistency_check() { assert!(initial_state.are_same(&the_four), "initial states are same"); - let transitioned = the_four.successors_async(variable, initial_state); + let transitioned = the_four.predecessors_async(variable, initial_state); assert!(transitioned.are_same(&the_four), "all are same");