update pyschlandals interface
AlexandreDubray committed Feb 9, 2024
1 parent d931ddd commit 0590943
Showing 14 changed files with 435 additions and 204 deletions.
9 changes: 0 additions & 9 deletions pyschlandals/example.py

This file was deleted.

30 changes: 24 additions & 6 deletions pyschlandals/example/compile.py
@@ -1,8 +1,26 @@
-from pyschlandals import BranchingHeuristic
-from pyschlandals.compiler import compile
+from pyschlandals.pwmc import PyProblem
 
-import sys
+problem = PyProblem()
+problem.add_distribution([0.2, 0.8])
+problem.add_distribution([0.3, 0.7])
+problem.add_distribution([0.4, 0.6])
+problem.add_distribution([0.1, 0.9])
+problem.add_distribution([0.5, 0.5])
 
-if __name__ == '__main__':
-    dac = compile(sys.argv[1], BranchingHeuristic.MinInDegree)
-    print(dac.get_circuit_probability())
+problem.add_clause([11, -1])
+problem.add_clause([12, -2])
+problem.add_clause([13, -11, -3])
+problem.add_clause([13, -12, -5])
+problem.add_clause([14, -11, -4])
+problem.add_clause([14, -12, -6])
+problem.add_clause([15, -13, -7])
+problem.add_clause([15, -14, -9])
+problem.add_clause([16, -13, -8])
+problem.add_clause([16, -14, -10])
+problem.add_clause([-15])
+
+# For one shot compilation/evaluation, just use compile():
+print(problem.compile())
+
+# However, you might want to store the compiled AC or visualize it as a DOT file
+problem.compile(fdac="out.ac", dotfile="ac.dot")
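A quick note on the example above: in the Rust binding, PyProblem::compile returns an Option<f64>, so on the Python side the call yields either the computed probability or None when the solver reports an error (the error message itself is only printed from Rust). A minimal sketch of how the result might be checked, reusing the problem object built above:

prob = problem.compile(fdac="out.ac", dotfile="ac.dot")
if prob is None:
    # The binding prints the solver error and returns None instead of raising.
    raise RuntimeError("compilation failed; see the solver's printed error")
print(f"compiled probability: {prob}")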
8 changes: 0 additions & 8 deletions pyschlandals/example/search.py

This file was deleted.

22 changes: 22 additions & 0 deletions pyschlandals/example/simple.py
@@ -0,0 +1,22 @@
+from pyschlandals.pwmc import PyProblem
+
+problem = PyProblem()
+problem.add_distribution([0.2, 0.8])
+problem.add_distribution([0.3, 0.7])
+problem.add_distribution([0.4, 0.6])
+problem.add_distribution([0.1, 0.9])
+problem.add_distribution([0.5, 0.5])
+
+problem.add_clause([11, -1])
+problem.add_clause([12, -2])
+problem.add_clause([13, -11, -3])
+problem.add_clause([13, -12, -5])
+problem.add_clause([14, -11, -4])
+problem.add_clause([14, -12, -6])
+problem.add_clause([15, -13, -7])
+problem.add_clause([15, -14, -9])
+problem.add_clause([16, -13, -8])
+problem.add_clause([16, -14, -10])
+problem.add_clause([-15])
+
+print(problem.solve())
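A note on the encoding used in this example (my reading of the interface, not stated in the commit): each add_distribution call appears to introduce one variable per listed probability, numbered consecutively from 1, so the five binary distributions occupy variables 1 to 10 and the first deterministic variable is 11. Clauses are lists of signed, DIMACS-style literals read as disjunctions; [13, -11, -3], for instance, encodes x13 ∨ ¬x11 ∨ ¬x3, i.e. the implication (x11 ∧ x3) ⇒ x13, and the final unit clause [-15] constrains the query. solve() returns the resulting weighted model count as a float, or None if the solver reports an error.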
5 changes: 5 additions & 0 deletions pyschlandals/example/train.csv
@@ -0,0 +1,5 @@
+Instance,Probability
+train_input/asia_xray_true.cnf,0.11029
+train_input/asia_xray_false.cnf,0.88971
+train_input/asia_dyspnea_true.cnf,0.435971
+train_input/asia_dyspnea_false.cnf,0.564029
4 changes: 4 additions & 0 deletions pyschlandals/example/train.py
@@ -0,0 +1,4 @@
+from pyschlandals.learn import learn, PyLearnParameters
+
+params = PyLearnParameters()
+learn("train.csv", params)
137 changes: 78 additions & 59 deletions pyschlandals/src/lib.rs
@@ -3,9 +3,12 @@ use pyo3::Python;
 use std::path::PathBuf;
 use schlandals::*;
 
+mod train;
+use train::*;
+
 #[pyclass]
-#[derive(Clone)]
-enum PyBranching {
+#[derive(Clone, Copy)]
+pub enum PyBranching {
     MinInDegree,
     MinOutDegree,
     MaxDegree,
@@ -21,100 +24,116 @@ fn get_branching_from_pybranching(branching: PyBranching) -> Branching {
 
 #[pyclass]
 #[derive(Clone)]
-enum PyLoss {
-    MAE,
-    MSE,
+struct PyProblem {
+    distributions: Vec<Vec<f64>>,
+    clauses: Vec<Vec<isize>>,
+    branching: PyBranching,
+    epsilon: f64,
+    timeout: u64,
+    memory_limit: u64,
+    statistics: bool,
 }
 
-fn get_loss_from_pyloss(loss: PyLoss) -> Loss {
-    match loss {
-        PyLoss::MAE => Loss::MAE,
-        PyLoss::MSE => Loss::MSE,
+#[pymethods]
+impl PyProblem {
+
+    #[new]
+    pub fn new(branching: Option<PyBranching>, epsilon: Option<f64>, timeout: Option<u64>, memory_limit: Option<u64>, statistics: Option<bool>) -> Self {
+        Self {
+            distributions: vec![],
+            clauses: vec![],
+            branching: if let Some(b) = branching { b } else { PyBranching::MinInDegree },
+            epsilon: if let Some(e) = epsilon { e } else { 0.0 },
+            timeout: if let Some(to) = timeout { to } else { u64::MAX },
+            memory_limit: if let Some(limit) = memory_limit { limit } else { u64::MAX },
+            statistics: if let Some(s) = statistics { s } else { false },
+        }
     }
-}
 
-// TODO: Find how to make the python binding to take into account that the tensors are a feature
-// that is not enabled by default
-#[pyclass]
-#[derive(Clone)]
-enum PySemiring {
-    Probability,
-    //Tensor,
-}
+    pub fn add_distribution(&mut self, distribution: Vec<f64>) {
+        self.distributions.push(distribution);
+    }
 
-fn get_semiring_from_pysemiring(semiring: PySemiring) -> Semiring {
-    match semiring {
-        PySemiring::Probability => Semiring::Probability,
-        //PySemiring::Tensor => Semiring::Tensor,
+    pub fn add_clause(&mut self, clause: Vec<isize>) {
+        self.clauses.push(clause);
     }
-}
 
-#[pyclass]
-#[derive(Clone)]
-enum PyOptimizer {
-    SGD,
-    Adam,
-}
+    pub fn solve(&self) -> Option<f64> {
+        match schlandals::solve_from_problem(&self.distributions, &self.clauses, get_branching_from_pybranching(self.branching), self.epsilon, Some(self.memory_limit), self.timeout, self.statistics) {
+            Ok(p) => Some(p.to_f64()),
+            Err(e) => {
+                println!("{:?}", e);
+                None
+            }
+        }
+    }
 
-fn get_optimizer_from_pyoptimizer(optimizer: PyOptimizer) -> Optimizer {
-    match optimizer {
-        PyOptimizer::SGD => Optimizer::SGD,
-        PyOptimizer::Adam => Optimizer::Adam,
+    pub fn compile(&self, fdac: Option<String>, dotfile: Option<String>) -> Option<f64> {
+        match schlandals::compile_from_problem(&self.distributions,
+            &self.clauses,
+            get_branching_from_pybranching(self.branching),
+            self.epsilon,
+            Some(self.memory_limit),
+            self.timeout,
+            self.statistics,
+            if let Some(path) = fdac { Some(PathBuf::from(path)) } else { None },
+            if let Some(path) = dotfile { Some(PathBuf::from(path)) } else { None },) {
+            Ok(p) => Some(p.to_f64()),
+            Err(e) => {
+                println!("{:?}", e);
+                None
+            }
+        }
     }
 }
 
 
 #[pyfunction]
 #[pyo3(name = "search")]
-fn pysearch(file: String, branching: PyBranching, epsilon: Option<f64>, memory_limit: Option<u64>) -> Option<f64> {
+fn pysearch(file: String, branching: PyBranching, epsilon: Option<f64>, memory_limit: Option<u64>, timeout: Option<u64>) -> Option<f64> {
     let e = if epsilon.is_none() {
         0.0
     } else {
         epsilon.unwrap()
     };
-    match schlandals::search(PathBuf::from(file), get_branching_from_pybranching(branching), false, memory_limit, e) {
+    let to = if timeout.is_none() { u64::MAX } else { timeout.unwrap() };
+    match schlandals::search(PathBuf::from(file), get_branching_from_pybranching(branching), false, memory_limit, e, to) {
         Err(_) => None,
         Ok(p) => Some(p.to_f64()),
     }
 }
 
 #[pyfunction]
 #[pyo3(name = "compile")]
-fn pycompile(file: String, branching: PyBranching, epsilon: Option<f64>, output_circuit: Option<String>, output_dot: Option<String>) -> Option<f64> {
+fn pycompile(file: String, branching: PyBranching, epsilon: Option<f64>, output_circuit: Option<String>, output_dot: Option<String>, timeout: Option<u64>) -> Option<f64> {
     let fdac = if let Some(file) = output_circuit { Some(PathBuf::from(file)) } else { None };
     let fdot = if let Some(file) = output_dot { Some(PathBuf::from(file)) } else { None };
     let e = if let Some(e) = epsilon { e } else { 0.0 };
-    match schlandals::compile(PathBuf::from(file), get_branching_from_pybranching(branching), fdac, fdot, e) {
+    let to = if timeout.is_none() { u64::MAX } else { timeout.unwrap() };
+    match schlandals::compile(PathBuf::from(file), get_branching_from_pybranching(branching), fdac, fdot, e, to) {
         Err(_) => None,
         Ok(p) => Some(p.to_f64()),
     }
 }
 
-#[pyfunction]
-#[pyo3(name = "learn")]
-fn pylearn(train_file: String, branching: PyBranching, learning_rate: f64, nepochs: usize, log: bool, timeout: u64, epsilon: f64, loss: PyLoss, jobs: usize, semiring: PySemiring, optimizer: PyOptimizer, test_file: Option<String>, outfolder: Option<PathBuf>) {
-    let b = get_branching_from_pybranching(branching);
-    let l = get_loss_from_pyloss(loss);
-    let s = get_semiring_from_pysemiring(semiring);
-    let o = get_optimizer_from_pyoptimizer(optimizer);
-    let train = PathBuf::from(train_file);
-    let test = if test_file.is_none() {
-        None
-    } else {
-        Some(PathBuf::from(test_file.unwrap()))
-    };
-    schlandals::learn(train, test, b, outfolder, learning_rate, nepochs, log, timeout, epsilon, l, jobs, s, o);
-}
+#[pymodule]
+#[pyo3(name="pwmc")]
+fn pwmc_submodule(py: Python<'_>, parent_module: &PyModule) -> PyResult<()> {
+    let module = PyModule::new(py, "pwmc")?;
+    module.add_class::<PyProblem>()?;
+    module.add_function(wrap_pyfunction!(pycompile, module)?)?;
+    module.add_function(wrap_pyfunction!(pysearch, module)?)?;
+
+    parent_module.add_submodule(module)?;
+    py.import("sys")?.getattr("modules")?.set_item("pyschlandals.pwmc", module)?;
+    Ok(())
+}
 
 /// Base module for pyschlandals
 #[pymodule]
-fn pyschlandals(_py: Python, m: &PyModule) -> PyResult<()> {
+fn pyschlandals(py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<PyBranching>()?;
-    m.add_class::<PyLoss>()?;
-    m.add_class::<PyOptimizer>()?;
-    m.add_class::<PySemiring>()?;
-    m.add_function(wrap_pyfunction!(pylearn, m)?).unwrap();
-    m.add_function(wrap_pyfunction!(pysearch, m)?).unwrap();
-    m.add_function(wrap_pyfunction!(pycompile, m)?).unwrap();
+    pwmc_submodule(py, m)?;
+    train::learn_submodule(py, m)?;
     Ok(())
 }
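For reference, a sketch of how the entry points registered above might be called from Python. It assumes pyo3 exposes the trailing Option<...> parameters as optional keyword arguments named after the Rust parameters; the CNF file name is a placeholder:

from pyschlandals import PyBranching
from pyschlandals.pwmc import PyProblem, search, compile as compile_cnf

# Exact search-based weighted model counting on a CNF file.
p = search("problem.cnf", PyBranching.MinInDegree)

# Compile the same file, writing the arithmetic circuit and a DOT rendering.
p = compile_cnf("problem.cnf", PyBranching.MinInDegree,
                output_circuit="out.ac", output_dot="out.dot")

# Programmatic construction with non-default settings.
problem = PyProblem(branching=PyBranching.MinInDegree, epsilon=0.0, statistics=True)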