diff --git a/src/_impl_bdd_variable_set.rs b/src/_impl_bdd_variable_set.rs index 69719fa..599678c 100644 --- a/src/_impl_bdd_variable_set.rs +++ b/src/_impl_bdd_variable_set.rs @@ -26,9 +26,39 @@ impl BddVariableSet { /// /// *Panics:* `vars` must contain unique names which are allowed as variable names. pub fn new(vars: &[&str]) -> BddVariableSet { - let mut builder = BddVariableSetBuilder::new(); - builder.make_variables(vars); - builder.build() + let num_vars = vars.len(); + if num_vars >= ((u16::MAX - 1) as usize) { + panic!( + "Too many BDD variables. There can be at most {} variables.", + u16::MAX - 1 + ) + } + let var_names: Vec = vars + .iter() + .map(|name| { + if name.chars().any(|c| NOT_IN_VAR_NAME.contains(&c)) { + panic!( + "Variable name {} is invalid. Cannot use {:?}", + name, NOT_IN_VAR_NAME + ); + } + name.to_string() + }) + .collect(); + let var_index_mapping: HashMap = vars + .iter() + .enumerate() + .map(|(id, name)| (name.to_string(), id as u16)) + .collect(); + + if var_index_mapping.len() != var_names.len() { + panic!("Existing duplicated BDD variable."); + } + BddVariableSet { + num_vars: num_vars as u16, + var_names, + var_index_mapping, + } } /// Return the number of variables in this set. @@ -591,4 +621,35 @@ mod tests { let ctx = BddVariableSet::new(&[]); assert_eq!("[]", ctx.to_string()); } + + #[test] + fn bdd_new_performance() { + fn old_version(vars: &[&str]) -> BddVariableSet { + let mut builder = BddVariableSetBuilder::new(); + builder.make_variables(vars); + builder.build() + } + // validate correctness + let ctx = BddVariableSet::new(&["a", "b", "x", "c", "y"]); + let ctx_old = old_version(&["a", "b", "x", "c", "y"]); + assert_eq!(ctx.num_vars, ctx_old.num_vars); + assert_eq!(ctx.var_names, ctx_old.var_names); + assert_eq!(ctx.var_index_mapping, ctx_old.var_index_mapping); + // validate performance + use std::time::SystemTime; + let n = 1000; + let t_bgn1 = SystemTime::now(); + for _ in 0..n { + let _ctx = BddVariableSet::new(&["a", "b", "x", "c", "y"]); + } + let t1 = t_bgn1.elapsed().unwrap(); + println!("New BddVariableSet::new runtime: {t1:?}"); + let t_bgn2 = SystemTime::now(); + for _ in 0..n { + let _ctx = old_version(&["a", "b", "x", "c", "y"]); + } + let t2 = t_bgn2.elapsed().unwrap(); + println!("Old BddVariableSet::new runtime: {t2:?}"); + assert!(t1 < t2); + } }