diff --git a/pyschlandals/example.py b/pyschlandals/example.py deleted file mode 100644 index 04df0e0..0000000 --- a/pyschlandals/example.py +++ /dev/null @@ -1,9 +0,0 @@ -from pyschlandals.search import exact -from pyschlandals.compiler import compile -from pyschlandals import BranchingHeuristic - -filename = '../tests/instances/bayesian_networks/asia_xray_false.cnf' -print(exact(filename, BranchingHeuristic.MinInDegree)) - -dac = compile(filename, BranchingHeuristic.MinInDegree) -print(dac.get_circuit_probability()) \ No newline at end of file diff --git a/pyschlandals/example/compile.py b/pyschlandals/example/compile.py index eaafa5e..e71b341 100644 --- a/pyschlandals/example/compile.py +++ b/pyschlandals/example/compile.py @@ -1,8 +1,26 @@ -from pyschlandals import BranchingHeuristic -from pyschlandals.compiler import compile +from pyschlandals.pwmc import PyProblem -import sys +problem = PyProblem() +problem.add_distribution([0.2, 0.8]) +problem.add_distribution([0.3, 0.7]) +problem.add_distribution([0.4, 0.6]) +problem.add_distribution([0.1, 0.9]) +problem.add_distribution([0.5, 0.5]) -if __name__ == '__main__': - dac = compile(sys.argv[1], BranchingHeuristic.MinInDegree) - print(dac.get_circuit_probability()) \ No newline at end of file +problem.add_clause([11, -1]) +problem.add_clause([12, -2]) +problem.add_clause([13, -11, -3]) +problem.add_clause([13, -12, -5]) +problem.add_clause([14, -11, -4]) +problem.add_clause([14, -12, -6]) +problem.add_clause([15, -13, -7]) +problem.add_clause([15, -14, -9]) +problem.add_clause([16, -13, -8]) +problem.add_clause([16, -14, -10]) +problem.add_clause([-15]) + +# For one shot compilation/evaluation, just use compile(): +print(problem.compile()) + +# However, you might want to store the compiled AC or visualize it as a DOT file +problem.compile(fdac="out.ac", dotfile="ac.dot") diff --git a/pyschlandals/example/search.py b/pyschlandals/example/search.py deleted file mode 100644 index 03916ab..0000000 --- a/pyschlandals/example/search.py +++ /dev/null @@ -1,8 +0,0 @@ -from pyschlandals import BranchingHeuristic -from pyschlandals.search import exact - -import sys - -if __name__ == '__main__': - proba = exact(sys.argv[1], BranchingHeuristic.MinInDegree) - print(proba) \ No newline at end of file diff --git a/pyschlandals/example/simple.py b/pyschlandals/example/simple.py new file mode 100644 index 0000000..5efa54b --- /dev/null +++ b/pyschlandals/example/simple.py @@ -0,0 +1,22 @@ +from pyschlandals.pwmc import PyProblem + +problem = PyProblem() +problem.add_distribution([0.2, 0.8]) +problem.add_distribution([0.3, 0.7]) +problem.add_distribution([0.4, 0.6]) +problem.add_distribution([0.1, 0.9]) +problem.add_distribution([0.5, 0.5]) + +problem.add_clause([11, -1]) +problem.add_clause([12, -2]) +problem.add_clause([13, -11, -3]) +problem.add_clause([13, -12, -5]) +problem.add_clause([14, -11, -4]) +problem.add_clause([14, -12, -6]) +problem.add_clause([15, -13, -7]) +problem.add_clause([15, -14, -9]) +problem.add_clause([16, -13, -8]) +problem.add_clause([16, -14, -10]) +problem.add_clause([-15]) + +print(problem.solve()) diff --git a/pyschlandals/example/train.csv b/pyschlandals/example/train.csv new file mode 100644 index 0000000..7d800db --- /dev/null +++ b/pyschlandals/example/train.csv @@ -0,0 +1,5 @@ +Instance,Probability +train_input/asia_xray_true.cnf,0.11029 +train_input/asia_xray_false.cnf,0.88971 +train_input/asia_dyspnea_true.cnf,0.435971 +train_input/asia_dyspnea_false.cnf,0.564029 diff --git a/pyschlandals/example/train.py b/pyschlandals/example/train.py new file mode 100644 index 0000000..131e2f4 --- /dev/null +++ b/pyschlandals/example/train.py @@ -0,0 +1,4 @@ +from pyschlandals.learn import learn, PyLearnParameters + +params = PyLearnParameters() +learn("train.csv", params) diff --git a/pyschlandals/src/lib.rs b/pyschlandals/src/lib.rs index 6ec990e..9180569 100644 --- a/pyschlandals/src/lib.rs +++ b/pyschlandals/src/lib.rs @@ -3,9 +3,12 @@ use pyo3::Python; use std::path::PathBuf; use schlandals::*; +mod train; +use train::*; + #[pyclass] -#[derive(Clone)] -enum PyBranching { +#[derive(Clone, Copy)] +pub enum PyBranching { MinInDegree, MinOutDegree, MaxDegree, @@ -21,57 +24,80 @@ fn get_branching_from_pybranching(branching: PyBranching) -> Branching { #[pyclass] #[derive(Clone)] -enum PyLoss { - MAE, - MSE, +struct PyProblem { + distributions: Vec>, + clauses: Vec>, + branching: PyBranching, + epsilon: f64, + timeout: u64, + memory_limit: u64, + statistics: bool, } -fn get_loss_from_pyloss(loss: PyLoss) -> Loss { - match loss { - PyLoss::MAE => Loss::MAE, - PyLoss::MSE => Loss::MSE, +#[pymethods] +impl PyProblem { + + #[new] + pub fn new(branching: Option, epsilon: Option, timeout: Option, memory_limit: Option, statistics: Option) -> Self { + Self { + distributions: vec![], + clauses: vec![], + branching: if let Some(b) = branching { b } else { PyBranching::MinInDegree }, + epsilon: if let Some(e) = epsilon { e } else { 0.0 }, + timeout: if let Some(to) = timeout { to } else { u64::MAX }, + memory_limit: if let Some(limit) = memory_limit { limit } else { u64::MAX }, + statistics: if let Some(s) = statistics { s } else { false }, + } } -} -// TODO: Find how to make the python binding to take into account that the tensors are a feature -// that is not enabled by default -#[pyclass] -#[derive(Clone)] -enum PySemiring { - Probability, - //Tensor, -} + pub fn add_distribution(&mut self, distribution: Vec) { + self.distributions.push(distribution); + } -fn get_semiring_from_pysemiring(semiring: PySemiring) -> Semiring { - match semiring { - PySemiring::Probability => Semiring::Probability, - //PySemiring::Tensor => Semiring::Tensor, + pub fn add_clause(&mut self, clause: Vec) { + self.clauses.push(clause); } -} -#[pyclass] -#[derive(Clone)] -enum PyOptimizer { - SGD, - Adam, -} + pub fn solve(&self) -> Option { + match schlandals::solve_from_problem(&self.distributions, &self.clauses, get_branching_from_pybranching(self.branching), self.epsilon, Some(self.memory_limit), self.timeout, self.statistics) { + Ok(p) => Some(p.to_f64()), + Err(e) => { + println!("{:?}", e); + None + } + } + } -fn get_optimizer_from_pyoptimizer(optimizer: PyOptimizer) -> Optimizer { - match optimizer { - PyOptimizer::SGD => Optimizer::SGD, - PyOptimizer::Adam => Optimizer::Adam, + pub fn compile(&self, fdac: Option, dotfile: Option) -> Option { + match schlandals::compile_from_problem(&self.distributions, + &self.clauses, + get_branching_from_pybranching(self.branching), + self.epsilon, + Some(self.memory_limit), + self.timeout, + self.statistics, + if let Some(path) = fdac { Some(PathBuf::from(path)) } else { None }, + if let Some(path) = dotfile { Some(PathBuf::from(path)) } else { None },) { + Ok(p) => Some(p.to_f64()), + Err(e) => { + println!("{:?}", e); + None + } + } } } + #[pyfunction] #[pyo3(name = "search")] -fn pysearch(file: String, branching: PyBranching, epsilon: Option, memory_limit: Option) -> Option { +fn pysearch(file: String, branching: PyBranching, epsilon: Option, memory_limit: Option, timeout: Option) -> Option { let e = if epsilon.is_none() { 0.0 } else { epsilon.unwrap() }; - match schlandals::search(PathBuf::from(file), get_branching_from_pybranching(branching), false, memory_limit, e) { + let to = if timeout.is_none() { u64::MAX } else { timeout.unwrap() }; + match schlandals::search(PathBuf::from(file), get_branching_from_pybranching(branching), false, memory_limit, e, to) { Err(_) => None, Ok(p) => Some(p.to_f64()), } @@ -79,42 +105,35 @@ fn pysearch(file: String, branching: PyBranching, epsilon: Option, memory_l #[pyfunction] #[pyo3(name = "compile")] -fn pycompile(file: String, branching: PyBranching, epsilon: Option, output_circuit: Option, output_dot: Option) -> Option { +fn pycompile(file: String, branching: PyBranching, epsilon: Option, output_circuit: Option, output_dot: Option, timeout: Option) -> Option { let fdac = if let Some(file) = output_circuit { Some(PathBuf::from(file)) } else { None }; let fdot = if let Some(file) = output_dot { Some(PathBuf::from(file)) } else { None }; let e = if let Some(e) = epsilon { e } else { 0.0 }; - match schlandals::compile(PathBuf::from(file), get_branching_from_pybranching(branching), fdac, fdot, e) { + let to = if timeout.is_none() { u64::MAX } else { timeout.unwrap() }; + match schlandals::compile(PathBuf::from(file), get_branching_from_pybranching(branching), fdac, fdot, e, to) { Err(_) => None, Ok(p) => Some(p.to_f64()), } } -#[pyfunction] -#[pyo3(name = "learn")] -fn pylearn(train_file: String, branching: PyBranching, learning_rate: f64, nepochs: usize, log: bool, timeout: u64, epsilon: f64, loss: PyLoss, jobs: usize, semiring: PySemiring, optimizer: PyOptimizer, test_file: Option, outfolder: Option) { - let b = get_branching_from_pybranching(branching); - let l = get_loss_from_pyloss(loss); - let s = get_semiring_from_pysemiring(semiring); - let o = get_optimizer_from_pyoptimizer(optimizer); - let train = PathBuf::from(train_file); - let test = if test_file.is_none() { - None - } else { - Some(PathBuf::from(test_file.unwrap())) - }; - schlandals::learn(train, test, b, outfolder, learning_rate, nepochs, log, timeout, epsilon, l, jobs, s, o); -} +#[pymodule] +#[pyo3(name="pwmc")] +fn pwmc_submodule(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { + let module = PyModule::new(py, "pwmc")?; + module.add_class::()?; + module.add_function(wrap_pyfunction!(pycompile, module)?)?; + module.add_function(wrap_pyfunction!(pysearch, module)?)?; + parent_module.add_submodule(module)?; + py.import("sys")?.getattr("modules")?.set_item("pyschlandals.pwmc", module)?; + Ok(()) +} /// Base module for pyschlandals #[pymodule] -fn pyschlandals(_py: Python, m: &PyModule) -> PyResult<()> { +fn pyschlandals(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_class::()?; - m.add_function(wrap_pyfunction!(pylearn, m)?).unwrap(); - m.add_function(wrap_pyfunction!(pysearch, m)?).unwrap(); - m.add_function(wrap_pyfunction!(pycompile, m)?).unwrap(); + pwmc_submodule(py, m)?; + train::learn_submodule(py, m)?; Ok(()) } diff --git a/pyschlandals/src/train.rs b/pyschlandals/src/train.rs new file mode 100644 index 0000000..e188ae9 --- /dev/null +++ b/pyschlandals/src/train.rs @@ -0,0 +1,148 @@ +use pyo3::prelude::*; +use super::*; + +use schlandals::learning::LearnParameters; + +#[pyclass] +#[derive(Clone)] +pub enum PyLoss { + MAE, + MSE, +} + +fn get_loss_from_pyloss(loss: PyLoss) -> Loss { + match loss { + PyLoss::MAE => Loss::MAE, + PyLoss::MSE => Loss::MSE, + } +} + + +#[pyclass] +#[derive(Clone)] +pub struct PyLearnParameters { + /// The initial learning rate + lr: f64, + /// The number of epochs + nepochs: usize, + /// The timeout used for the compilations, in seconds + compilation_timeout: u64, + /// The timeout used for the training loop, in seconds + learn_timeout: u64, + /// The loss function + loss: PyLoss, + /// The optimizer + optimizer: PyOptimizer, + /// The learning rate decay + lr_drop: f64, + /// The number of epochs after which the learning rate is dropped + epoch_drop: usize, + /// The error threshold under which the training is stopped + early_stop_threshold: f64, + /// The minimum delta between two epochs to consider that the training is still improving + early_stop_delta: f64, + /// The number of epochs to wait before stopping the training if the loss is not improving + patience: usize, +} + +#[pymethods] +impl PyLearnParameters { + + #[new] + pub fn new(lr: Option, nepochs: Option, compilation_timeout: Option, learn_timeout: Option, loss: Option, optimizer: Option, lr_drop: Option, epoch_drop: Option, early_stop_threshold: Option, early_stop_delta: Option, patience: Option) -> Self { + Self { + lr: if let Some(v) = lr { v } else { 0.3 }, + nepochs: if let Some(v) = nepochs { v } else { 6000 }, + compilation_timeout: if let Some(v) = compilation_timeout { v } else { u64::MAX }, + learn_timeout: if let Some(v) = learn_timeout { v } else { u64::MAX }, + loss: if let Some(v) = loss { v } else { PyLoss::MAE }, + optimizer: if let Some(v) = optimizer { v } else { PyOptimizer::Adam }, + lr_drop: if let Some(v) = lr_drop { v } else { 0.75 }, + epoch_drop: if let Some(v) = epoch_drop { v } else { 100 }, + early_stop_threshold: if let Some(v) = early_stop_threshold { v } else { 0.0001 }, + early_stop_delta: if let Some(v) = early_stop_delta { v } else { 0.00001 }, + patience: if let Some(v) = patience { v } else { 5 }, + } + } +} + +fn get_param_from_pyparam(param: PyLearnParameters) -> LearnParameters { + LearnParameters::new(param.lr, + param.nepochs, + param.compilation_timeout, + param.learn_timeout, + get_loss_from_pyloss(param.loss), + get_optimizer_from_pyoptimizer(param.optimizer), + param.lr_drop, + param.epoch_drop, + param.early_stop_threshold, + param.early_stop_delta, + param.patience) +} + +// TODO: Find how to make the python binding to take into account that the tensors are a feature +// that is not enabled by default +#[pyclass] +#[derive(Clone)] +pub enum PySemiring { + Probability, + //Tensor, +} + +fn get_semiring_from_pysemiring(semiring: PySemiring) -> Semiring { + match semiring { + PySemiring::Probability => Semiring::Probability, + //PySemiring::Tensor => Semiring::Tensor, + } +} + +#[pyclass] +#[derive(Clone)] +pub enum PyOptimizer { + SGD, + Adam, +} + +fn get_optimizer_from_pyoptimizer(optimizer: PyOptimizer) -> Optimizer { + match optimizer { + PyOptimizer::SGD => Optimizer::SGD, + PyOptimizer::Adam => Optimizer::Adam, + } +} + +#[pyfunction] +#[pyo3(name = "learn")] +pub fn pylearn(train_file: String, param: PyLearnParameters, branching: Option, semiring: Option, log: Option, epsilon: Option, jobs: Option, test_file: Option, outfolder: Option) { + let train = PathBuf::from(train_file); + let test = if test_file.is_none() { + None + } else { + Some(PathBuf::from(test_file.unwrap())) + }; + schlandals::learn(train, + test, + get_branching_from_pybranching(if let Some(v) = branching { v } else { PyBranching::MinInDegree }), + outfolder, + if let Some(v) = log { v } else { false }, + if let Some(v) = epsilon { v } else { 0.0 }, + if let Some(v) = jobs { v } else { 1 }, + get_semiring_from_pysemiring(if let Some(v) = semiring { v } else { PySemiring::Probability }), + get_param_from_pyparam(param)) +} + +#[pymodule] +#[pyo3(name="learn")] +pub fn learn_submodule(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> { + let module = PyModule::new(py, "learn")?; + + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_function(wrap_pyfunction!(pylearn, module)?)?; + + parent_module.add_submodule(module)?; + py.import("sys")?.getattr("modules")?.set_item("pyschlandals.learn", module)?; + Ok(()) +} + diff --git a/src/learning/mod.rs b/src/learning/mod.rs index d08b0cf..bb7e662 100644 --- a/src/learning/mod.rs +++ b/src/learning/mod.rs @@ -32,7 +32,7 @@ pub struct LearnParameters { /// The number of epochs nepochs: usize, /// The timeout used for the compilations, in seconds - timeout: u64, + compilation_timeout: u64, /// The timeout used for the training loop, in seconds learn_timeout: u64, /// The loss function @@ -53,8 +53,8 @@ pub struct LearnParameters { impl LearnParameters { - pub fn new(lr: f64, nepochs: usize, timeout: u64, learn_timeout: u64, loss: Loss, optimizer: Optimizer, lr_drop: f64, epoch_drop: usize, early_stop_threshold: f64, early_stop_delta: f64, patience: usize) -> Self { - Self { lr, nepochs, timeout, learn_timeout, loss, optimizer, lr_drop, epoch_drop, early_stop_threshold, early_stop_delta, patience } + pub fn new(lr: f64, nepochs: usize, compilation_timeout: u64, learn_timeout: u64, loss: Loss, optimizer: Optimizer, lr_drop: f64, epoch_drop: usize, early_stop_threshold: f64, early_stop_delta: f64, patience: usize) -> Self { + Self { lr, nepochs, compilation_timeout, learn_timeout, loss, optimizer, lr_drop, epoch_drop, early_stop_threshold, early_stop_delta, patience } } /// Returns the learning rate @@ -68,8 +68,8 @@ impl LearnParameters { } /// Returns the timeout used for the compilation of the queries (in seconds) - pub fn timeout(&self) -> u64 { - self.timeout + pub fn compilation_timeout(&self) -> u64 { + self.compilation_timeout } /// Return the timeout used for the learning loop (in seconds) diff --git a/src/learning/utils.rs b/src/learning/utils.rs index 42c9e2a..8f77487 100644 --- a/src/learning/utils.rs +++ b/src/learning/utils.rs @@ -17,7 +17,6 @@ use rug::Float; use std::path::PathBuf; use rayon::prelude::*; -use crate::branching::*; use crate::Branching; use search_trail::StateManager; use crate::core::components::ComponentExtractor; @@ -39,8 +38,7 @@ pub fn softmax(x: &[f64]) -> Vec { } /// Generates a vector of optional Dacs from a list of input files -pub fn generate_dacs(inputs: Vec, branching: Branching, epsilon: f64, timeout: u64) -> Vec>> - where R: SemiRing +pub fn generate_dacs(inputs: Vec, branching: Branching, epsilon: f64, timeout: u64) -> Vec>> { inputs.par_iter().map(|input| { // We compile the input. This can either be a .cnf file or a fdac file. @@ -49,7 +47,7 @@ pub fn generate_dacs(inputs: Vec, branching: Branching, epsilon: f64 FileType::CNF => { println!("Compiling {}", input.to_str().unwrap()); // The input is a CNF file, we need to compile it from scratch - let compiler = make_solver!(input, branching, epsilon, None, timeout, false); + let compiler = make_solver!(&input, branching, epsilon, None, timeout, false); compile!(compiler) }, FileType::FDAC => { diff --git a/src/lib.rs b/src/lib.rs index 0418acd..d47183a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,6 @@ use clap::ValueEnum; use rug::Float; use crate::core::components::ComponentExtractor; -use crate::branching::*; use solvers::ProblemSolution; use crate::solvers::*; use crate::parser::*; @@ -80,34 +79,48 @@ pub enum Optimizer { SGD, } +pub fn solve_from_problem(distributions: &Vec>, clauses: &Vec>, branching: Branching, epsilon: f64, memory: Option, timeout: u64, statistics: bool) -> ProblemSolution { + let solver = solver_from_problem!(distributions, clauses, branching, epsilon, memory, timeout, statistics); + search!(solver) +} + +pub fn search(input: PathBuf, branching: Branching, statistics: bool, memory: Option, epsilon: f64, timeout: u64) -> ProblemSolution { + let solver = make_solver!(&input, branching, epsilon, memory, timeout, statistics); + search!(solver) +} + +fn _compile(compiler: GenericSolver, fdac: Option, dotfile: Option) -> ProblemSolution { + let mut res: Option> = compile!(compiler); + if let Some(ref mut dac) = &mut res { + dac.evaluate(); + let proba = dac.circuit_probability().clone(); + if let Some(f) = dotfile { + let out = dac.as_graphviz(); + let mut outfile = File::create(f).unwrap(); + match outfile.write(out.as_bytes()) { + Ok(_) => (), + Err(e) => println!("Could not write the circuit into the dot file: {:?}", e), + } + } + if let Some(f) = fdac { + let mut outfile = File::create(f).unwrap(); + match outfile.write(format!("{}", dac).as_bytes()) { + Ok(_) => (), + Err(e) => println!("Could not write the circuit into the fdac file: {:?}", e), + } + + } + ProblemSolution::Ok(proba) + } else { + ProblemSolution::Err(Error::Timeout) + } +} + pub fn compile(input: PathBuf, branching: Branching, fdac: Option, dotfile: Option, epsilon: f64, timeout: u64) -> ProblemSolution { match type_of_input(&input) { FileType::CNF => { let compiler = make_solver!(&input, branching, epsilon, None, timeout, false); - let mut res: Option> = compile!(compiler); - if let Some(ref mut dac) = &mut res { - dac.evaluate(); - let proba = dac.circuit_probability().clone(); - if let Some(f) = dotfile { - let out = dac.as_graphviz(); - let mut outfile = File::create(f).unwrap(); - match outfile.write(out.as_bytes()) { - Ok(_) => (), - Err(e) => println!("Could not write the circuit into the dot file: {:?}", e), - } - } - if let Some(f) = fdac { - let mut outfile = File::create(f).unwrap(); - match outfile.write(format!("{}", dac).as_bytes()) { - Ok(_) => (), - Err(e) => println!("Could not write the circuit into the fdac file: {:?}", e), - } - - } - ProblemSolution::Ok(proba) - } else { - ProblemSolution::Err(Error::Timeout) - } + _compile(compiler, fdac, dotfile) }, FileType::FDAC => { let mut dac: Dac = Dac::::from_file(&input); @@ -118,21 +131,27 @@ pub fn compile(input: PathBuf, branching: Branching, fdac: Option, dotf } } +pub fn compile_from_problem(distributions: &Vec>, clauses: &Vec>, branching: Branching, epsilon: f64, memory: Option, timeout: u64, statistics: bool, fdac: Option, dotfile: Option) -> ProblemSolution { + let solver = solver_from_problem!(distributions, clauses, branching, epsilon, memory, timeout, statistics); + _compile(solver, fdac, dotfile) +} + + pub fn make_learner(inputs: Vec, expected: Vec, epsilon: f64, branching: Branching, outfolder: Option, jobs: usize, log: bool, semiring: Semiring, params: &LearnParameters, test_inputs:Vec, test_expected:Vec) -> Box { match semiring { Semiring::Probability => { if log { - Box::new(Learner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.timeout(), test_inputs, test_expected)) + Box::new(Learner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.compilation_timeout(), test_inputs, test_expected)) } else { - Box::new(Learner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.timeout(), test_inputs, test_expected)) + Box::new(Learner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.compilation_timeout(), test_inputs, test_expected)) } }, #[cfg(feature = "tensor")] Semiring::Tensor => { if log { - Box::new(TensorLearner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.timeout(), params.optimizer, test_inputs, test_expected)) + Box::new(TensorLearner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.compilation_timeout(), params.optimizer, test_inputs, test_expected)) } else { - Box::new(TensorLearner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.timeout(), params.optimizer, test_inputs, test_expected)) + Box::new(TensorLearner::::new(inputs, expected, epsilon, branching, outfolder, jobs, params.compilation_timeout(), params.optimizer, test_inputs, test_expected)) } } } @@ -167,12 +186,6 @@ pub fn learn(trainfile: PathBuf, testfile:Option, branching: Branching, learner.train(¶ms); } -pub fn search(input: PathBuf, branching: Branching, statistics: bool, memory: Option, epsilon: f64, timeout: u64) -> ProblemSolution { - let solver = make_solver!(&input, branching, epsilon, memory, timeout, statistics); - search!(solver) -} - - impl std::fmt::Display for Loss { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/src/main.rs b/src/main.rs index ab827c4..8d68f52 100644 --- a/src/main.rs +++ b/src/main.rs @@ -108,7 +108,7 @@ enum Command { #[clap(long, short, default_value_t=schlandals::Semiring::Probability, value_enum)] semiring: schlandals::Semiring, /// The optimizer to use if `tensor` is selected as semiring - #[clap(long, short, default_value_t=schlandals::Optimizer::SGD, value_enum)] + #[clap(long, short, default_value_t=schlandals::Optimizer::Adam, value_enum)] optimizer: schlandals::Optimizer, /// The drop in the learning rate to apply at each step #[clap(long, default_value_t=0.75)] diff --git a/src/parser.rs b/src/parser.rs index cb05550..fb3af38 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -54,6 +54,36 @@ pub enum FileType { FDAC, } +pub fn graph_from_problem(distributions: &Vec>, clauses: &Vec>, state: &mut StateManager) -> Graph { + let mut number_var = 0; + for clause in clauses.iter() { + number_var = number_var.max(clause.iter().map(|l| l.abs() as usize).max().unwrap()); + } + let mut g = Graph::new(state, number_var, clauses.len()); + g.add_distributions(&distributions, state); + for clause in clauses.iter() { + let mut literals: Vec = vec![]; + let mut head: Option = None; + for lit in clause.iter().copied() { + if lit == 0 { + panic!("Variables in clauses can not be 0"); + } + let var = VariableIndex(lit.abs() as usize - 1); + let trail_value_index = g[var].get_value_index(); + let literal = Literal::from_variable(var, lit > 0, trail_value_index); + if lit > 0 { + if head.is_some() { + panic!("The clauses {} has more than one positive literal", clause.iter().map(|i| format!("{}", i)).collect::>().join(" ")); + } + head = Some(literal); + } + literals.push(literal); + } + g.add_clause(literals, head, state, false); + } + g +} + pub fn graph_from_ppidimacs( filepath: &PathBuf, state: &mut StateManager, @@ -72,13 +102,9 @@ pub fn graph_from_ppidimacs( let number_var = split_header.nth(2).unwrap().parse::().unwrap(); let number_clauses = split_header.next().unwrap().parse::().unwrap(); - let mut number_probabilistic = 0; let mut g = Graph::new(state, number_var, number_clauses); - for d in distributions.iter() { - number_probabilistic += d.len(); - } g.add_distributions(&distributions, state); // Second pass to parse the clauses @@ -104,13 +130,7 @@ pub fn graph_from_ppidimacs( for lit in clause.split_whitespace() { let trail_value_index = g[VariableIndex(lit.parse::().unwrap().abs() as usize - 1)].get_value_index(); let literal = Literal::from_str(lit, trail_value_index); - if literal.to_variable().0 < number_probabilistic { - // The variable is probabilistic, put at the end of the vector - literals.push(literal); - } else { - // The variable is deterministic, put at the beginning of the vector - literals.insert(0, literal); - } + literals.push(literal); if literal.is_positive() { if head.is_some() { panic!("The clause {} has multiple positive literals", line); @@ -189,29 +209,3 @@ pub fn type_of_input(filepath: &PathBuf) -> FileType { panic!("Unexpected file format to read from. Header does not match .cnf or .fdac file: {}", header); } } -/* -#[cfg(test)] -mod test_ppidimacs_parsing { - - use super::graph_from_ppidimacs; - use crate::core::graph::VariableIndex; - use crate::core::trail::StateManager; - use std::path::PathBuf; - - #[test] - fn test_file() { - let mut file = PathBuf::new(); - let mut state = StateManager::default(); - file.push("tests/instances/bayesian_networks/abc_chain_b0.ppidimacs"); - let (g, _) = graph_from_ppidimacs(&file, &mut state); - // Nodes for the distributions, the deterministics + 1 node for the vb0 -> False - assert_eq!(17, g.number_nodes()); - assert_eq!(5, g.number_distributions()); - - let nodes: Vec = g.nodes_iter().collect(); - for i in 0..10 { - assert!(g.is_node_probabilistic(nodes[i])); - } - } -} -*/ diff --git a/src/solvers/mod.rs b/src/solvers/mod.rs index cd3c367..7d2b7b1 100644 --- a/src/solvers/mod.rs +++ b/src/solvers/mod.rs @@ -15,9 +15,14 @@ //along with this program. If not, see . -use rug::Float; +use crate::core::graph::Graph; use crate::branching::*; +use crate::core::components::ComponentExtractor; +use crate::propagator::Propagator; +use crate::Branching; +use search_trail::StateManager; +use rug::Float; use std::hash::Hash; use bitvec::prelude::*; @@ -48,6 +53,65 @@ pub enum GenericSolver { QVSIDS(Solver), } +pub fn generic_solver(graph: Graph, state: StateManager, component_extractor: ComponentExtractor, branching: Branching, propagator: Propagator, mlimit: u64, epsilon: f64, timeout: u64, stat: bool) -> GenericSolver { + if stat { + match branching { + Branching::MinInDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::SMinInDegree(solver) + }, + Branching::MinOutDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::SMinOutDegree(solver) + }, + Branching::MaxDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::SMaxDegree(solver) + }, + Branching::VSIDS => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::SVSIDS(solver) + }, + } + } else { + match branching { + Branching::MinInDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::QMinInDegree(solver) + }, + Branching::MinOutDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::QMinOutDegree(solver) + }, + Branching::MaxDegree => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::QMaxDegree(solver) + }, + Branching::VSIDS => { + let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, epsilon, timeout); + GenericSolver::QVSIDS(solver) + }, + } + } +} + +macro_rules! solver_from_problem { + ($d:expr, $c:expr, $b:expr, $e:expr, $m:expr, $t:expr, $s:expr) => { + { + let mut state = StateManager::default(); + let graph = graph_from_problem($d, $c, &mut state); + let propagator = Propagator::new(&mut state); + let component_extractor = ComponentExtractor::new(&graph, &mut state); + let mlimit = if let Some(m) = $m { + m + } else { + u64::MAX + }; + generic_solver(graph, state, component_extractor, $b, propagator, mlimit, $e, $t, $s) + } + }; +} + macro_rules! make_solver { ($i:expr, $b:expr, $e:expr, $m:expr, $t: expr, $s:expr) => { { @@ -60,45 +124,7 @@ macro_rules! make_solver { } else { u64::MAX }; - if $s { - match $b { - Branching::MinInDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::SMinInDegree(solver) - }, - Branching::MinOutDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::SMinOutDegree(solver) - }, - Branching::MaxDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::SMaxDegree(solver) - }, - Branching::VSIDS => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::SVSIDS(solver) - }, - } - } else { - match $b { - Branching::MinInDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::QMinInDegree(solver) - }, - Branching::MinOutDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::QMinOutDegree(solver) - }, - Branching::MaxDegree => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::QMaxDegree(solver) - }, - Branching::VSIDS => { - let solver = Solver::::new(graph, state, component_extractor, Box::::default(), propagator, mlimit, $e, $t); - GenericSolver::QVSIDS(solver) - }, - } - } + generic_solver(graph, state, component_extractor, $b, propagator, mlimit, $e, $t, $s) } }; } @@ -133,6 +159,7 @@ macro_rules! compile { } } +pub(crate) use solver_from_problem; pub(crate) use make_solver; pub(crate) use compile; pub(crate) use search;