From 815c1332a5e99f3bfc4c74148d70d6426cd1db99 Mon Sep 17 00:00:00 2001 From: Alexandre Dubray Date: Fri, 22 Sep 2023 16:04:11 +0200 Subject: [PATCH] [pyschlandals] get output of distribution nodes --- pyschlandals/src/lib.rs | 11 ++ src/compiler/circuit.rs | 13 +- src/compiler/exact.rs | 44 +++-- src/core/components.rs | 25 +-- src/core/graph.rs | 40 +++++ src/heuristics/branching_approximate.rs | 53 ++++++ .../{branching.rs => branching_exact.rs} | 157 +----------------- src/heuristics/mod.rs | 23 ++- src/lib.rs | 9 +- src/search/approximate.rs | 2 +- src/search/sequential.rs | 2 +- 11 files changed, 185 insertions(+), 194 deletions(-) create mode 100644 src/heuristics/branching_approximate.rs rename src/heuristics/{branching.rs => branching_exact.rs} (56%) diff --git a/pyschlandals/src/lib.rs b/pyschlandals/src/lib.rs index addfdf5..4260038 100644 --- a/pyschlandals/src/lib.rs +++ b/pyschlandals/src/lib.rs @@ -161,6 +161,17 @@ impl PyDac { self.dac.get_distribution_domain_size(DistributionNodeIndex(distribution)) } + /// Returns the pair (circuit node, value index) of the output of the distribution at its given output-index + pub fn get_distribution_node_output_at(&self, distribution: usize, index: usize) -> (usize, usize) { + let x = self.dac.get_distribution_output_at(DistributionNodeIndex(distribution), index); + (x.0.0, x.1) + } + + /// Returns the number of output of a distribution node + pub fn get_distribution_number_output(&self, distribution: usize) -> usize { + self.dac.get_distribution_number_output(DistributionNodeIndex(distribution)) + } + /// Returns the probability, of the given distribution, at the given index pub fn get_distribution_probability(&self, distribution: usize, probability_index: usize) -> f64 { self.dac.get_distribution_probability_at(DistributionNodeIndex(distribution), probability_index) diff --git a/src/compiler/circuit.rs b/src/compiler/circuit.rs index 5546ac4..f4a1a37 100644 --- a/src/compiler/circuit.rs +++ b/src/compiler/circuit.rs @@ -70,6 +70,7 @@ pub struct CircuitNode { /// A distribution node, an input of the circuit. Each distribution node holds the distribution's parameter as well as the outputs. /// For each output node, it also stores the value that must be sent to the output (as an index of the probability vector). struct DistributionNode { + /// Probabilities of the distribution probabilities: Vec, /// Outputs of the node @@ -157,7 +158,7 @@ impl Dac { /// Adds `output` to the outputs of `node` and `node` to the inputs of `output`. Note that this /// function uses the vectors in each node. They are transferred afterward in the `outputs` vector. - pub fn add_spnode_output(&mut self, node: CircuitNodeIndex, output: CircuitNodeIndex) { + pub fn add_circuit_node_output(&mut self, node: CircuitNodeIndex, output: CircuitNodeIndex) { self.nodes[node.0].outputs.push(output); self.nodes[node.0].number_outputs += 1; self.nodes[output.0].inputs.insert(node); @@ -509,6 +510,16 @@ impl Dac { self.distribution_nodes[distribution.0].probabilities[index] } + /// Returns the pair (circuit_node, index) for the output of the distribution at the given index + pub fn get_distribution_output_at(&self, distribution: DistributionNodeIndex, index: usize) -> (CircuitNodeIndex, usize) { + self.distribution_nodes[distribution.0].outputs[index] + } + + /// Returns the number of output of a distribution node + pub fn get_distribution_number_output(&self, distribution: DistributionNodeIndex) -> usize { + self.distribution_nodes[distribution.0].outputs.len() + } + // --- SETTERS --- // /// Set the probability of the distribution, at the given index, to the given value diff --git a/src/compiler/exact.rs b/src/compiler/exact.rs index 8a39b90..828e312 100644 --- a/src/compiler/exact.rs +++ b/src/compiler/exact.rs @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; -use crate::heuristics::branching::BranchingDecision; +use crate::heuristics::BranchingDecision; use crate::propagator::CompiledPropagator; use crate::common::*; use crate::compiler::circuit::*; @@ -77,16 +77,14 @@ where } } - fn expand_sum_node(&mut self, spn: &mut Dac, component: ComponentIndex, distribution: DistributionIndex) -> Option { + fn expand_sum_node(&mut self, dac: &mut Dac, component: ComponentIndex, distribution: DistributionIndex) -> Option { let mut children: Vec = vec![]; for variable in self.graph.distribution_variable_iter(distribution) { self.state.save_state(); match self.propagator.propagate_variable(variable, true, &mut self.graph, &mut self.state, component, &self.component_extractor) { - Err(_) => { - - }, + Err(_) => { }, Ok(_) => { - if let Some(child) = self.expand_prod_node(spn, component) { + if let Some(child) = self.expand_prod_node(dac, component) { children.push(child); } } @@ -94,9 +92,9 @@ where self.state.restore_state(); } if !children.is_empty() { - let node = spn.add_sum_node(); + let node = dac.add_sum_node(); for child in children { - spn.add_spnode_output(child, node); + dac.add_circuit_node_output(child, node); } Some(node) } else { @@ -104,26 +102,26 @@ where } } - fn expand_prod_node(&mut self, spn: &mut Dac, component: ComponentIndex) -> Option { + fn expand_prod_node(&mut self, dac: &mut Dac, component: ComponentIndex) -> Option { let mut prod_node: Option = if self.propagator.has_assignments() || self.propagator.has_unconstrained_distribution() { - let node = spn.add_prod_node(); + let node = dac.add_prod_node(); for (distribution, variable, value) in self.propagator.assignments_iter() { if value { let value_id = variable.0 - self.graph.get_distribution_start(distribution).0; - spn.add_distribution_output(distribution, node, value_id); + dac.add_distribution_output(distribution, node, value_id); } } for distribution in self.propagator.unconstrained_distributions_iter() { if self.graph.distribution_number_false(distribution, &self.state) != 0 { - let sum_node = spn.add_sum_node(); + let sum_node = dac.add_sum_node(); for variable in self.graph.distribution_variable_iter(distribution) { if !self.graph.is_variable_fixed(variable, &self.state) { let value_id = variable.0 - self.graph.get_distribution_start(distribution).0; - spn.add_distribution_output(distribution, sum_node, value_id); + dac.add_distribution_output(distribution, sum_node, value_id); } } - spn.add_spnode_output(sum_node, node); + dac.add_circuit_node_output(sum_node, node); } } Some(node) @@ -137,7 +135,7 @@ where match self.cache.get(&bit_repr) { None => { if let Some(distribution) = self.branching_heuristic.branch_on(&self.graph, &self.state, &self.component_extractor, sub_component) { - if let Some(child) = self.expand_sum_node(spn, sub_component, distribution) { + if let Some(child) = self.expand_sum_node(dac, sub_component, distribution) { sum_children.push(child); self.cache.insert(bit_repr, Some(child)); } else { @@ -166,11 +164,11 @@ where } } if !sum_children.is_empty() && prod_node.is_none() { - prod_node = Some(spn.add_prod_node()); + prod_node = Some(dac.add_prod_node()); } if let Some(node) = prod_node { for child in sum_children { - spn.add_spnode_output(child, node); + dac.add_circuit_node_output(child, node); } } prod_node @@ -185,14 +183,14 @@ where Err(_) => None, Ok(_) => { self.branching_heuristic.init(&self.graph, &self.state); - let mut spn = Dac::new(&self.graph); - match self.expand_prod_node(&mut spn, ComponentIndex(0)) { + let mut dac = Dac::new(&self.graph); + match self.expand_prod_node(&mut dac, ComponentIndex(0)) { None => None, Some(_) => { - spn.remove_dead_ends(); - spn.reduce(); - spn.layerize(); - Some(spn) + dac.remove_dead_ends(); + dac.reduce(); + dac.layerize(); + Some(dac) } } } diff --git a/src/core/components.rs b/src/core/components.rs index 702d786..c263694 100644 --- a/src/core/components.rs +++ b/src/core/components.rs @@ -57,7 +57,7 @@ pub struct ComponentExtractor { /// vector clauses: Vec, /// The vector mapping for each `ClauseIndex` its position in `clauses` - positions: Vec, + clause_positions: Vec, /// Holds the components computed by the extractor during the search components: Vec, /// The index of the first component of the current node in the search tree @@ -110,7 +110,7 @@ impl ComponentExtractor { /// Creates a new component extractor for the implication graph `g` pub fn new(g: &Graph, state: &mut StateManager) -> Self { let nodes = (0..g.number_clauses()).map(ClauseIndex).collect(); - let positions = (0..g.number_clauses()).collect(); + let clause_positions = (0..g.number_clauses()).collect(); let distributions = (0..g.number_distributions()).map(DistributionIndex).collect(); let distribution_positions = (0..g.number_distributions()).collect(); let components = vec![Component { @@ -122,7 +122,7 @@ impl ComponentExtractor { }]; Self { clauses: nodes, - positions, + clause_positions, components, base: state.manage_usize(0), limit: state.manage_usize(1), @@ -146,7 +146,7 @@ impl ComponentExtractor { ) -> bool { // if the clause has already been visited, then its position in the component must // be between [start..(start + size)]. - let clause_pos = self.positions[clause.0]; + let clause_pos = self.clause_positions[clause.0]; g.is_clause_constrained(clause, state) && !(comp_start <= clause_pos && clause_pos < (comp_start + *comp_size)) } @@ -171,15 +171,15 @@ impl ComponentExtractor { if self.is_node_visitable(g, clause, comp_start, comp_size, state) { *hash ^= g.get_clause_random(clause); // The clause is swap with the clause at position comp_sart + comp_size - let current_pos = self.positions[clause.0]; + let current_pos = self.clause_positions[clause.0]; let new_pos = comp_start + *comp_size; // Only move the nodes if it is not already in position // Not sure if this optimization is worth in practice if new_pos != current_pos { let moved_node = self.clauses[new_pos]; self.clauses.as_mut_slice().swap(new_pos, current_pos); - self.positions[clause.0] = new_pos; - self.positions[moved_node.0] = current_pos; + self.clause_positions[clause.0] = new_pos; + self.clause_positions[moved_node.0] = current_pos; } *comp_size += 1; @@ -329,6 +329,7 @@ impl ComponentExtractor { let end = start + self.components[component.0].number_distribution; self.distributions[start..end].iter().copied() } + } /// This structure is used to implement a simple component detector that always returns one @@ -433,7 +434,7 @@ mod test_component_detection { let mut state = StateManager::default(); let g = get_graph(&mut state); let extractor = ComponentExtractor::new(&g, &mut state); - assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions); + assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions); check_component(&extractor, 0, 6, (0..6).map(ClauseIndex).collect::>()); assert_eq!(0, state.get_usize(extractor.base)); assert_eq!(1, state.get_usize(extractor.limit)); @@ -450,7 +451,7 @@ mod test_component_detection { extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator); assert_eq!(2, extractor.number_components(&state)); - assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions); + assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions); check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::>()); check_component(&extractor, 5, 6, vec![ClauseIndex(5)]); @@ -473,7 +474,7 @@ mod test_component_detection { extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator); assert_eq!(2, extractor.number_components(&state)); - assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions); + assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions); check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::>()); check_component(&extractor, 5, 6, vec![ClauseIndex(5)]); @@ -487,13 +488,13 @@ mod test_component_detection { state.restore_state(); assert_eq!(2, extractor.number_components(&state)); - assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions); + assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions); check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::>()); check_component(&extractor, 5, 6, vec![ClauseIndex(5)]); state.restore_state(); - assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions); + assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions); check_component(&extractor, 0, 6, (0..6).map(ClauseIndex).collect::>()); } } \ No newline at end of file diff --git a/src/core/graph.rs b/src/core/graph.rs index f77b89e..32e46d4 100644 --- a/src/core/graph.rs +++ b/src/core/graph.rs @@ -642,6 +642,46 @@ impl Graph { } } + fn get_distribution_maximum(&self, distribution: DistributionIndex, state: &StateManager) -> f64 { + let mut max = 0.0; + for v in self.distribution_variable_iter(distribution) { + if !self.is_variable_fixed(v, state) { + let proba = self.get_variable_weight(v).unwrap(); + if max < proba { + max = proba; + } + } + } + max + } + + /// Returns the distribution with the unassgined variable with the highest probability. + pub fn get_clause_active_distribution_highest_value(&self, clause: ClauseIndex, state: &StateManager) -> Option<(DistributionIndex, f64)> { + let number_probabilistic_watched = self.get_clause_body_bounds_probabilistic(clause, state); + if number_probabilistic_watched == 0 { + None + } else if number_probabilistic_watched == 1 { + let d = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[0]).unwrap(); + let proba = self.get_distribution_maximum(d, state); + Some((d, proba)) + } else { + let d1 = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[0]).unwrap(); + let d2 = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[1]).unwrap(); + if d1 == d2 { + let proba_1 = self.get_distribution_maximum(d1, state); + Some((d1, proba_1)) + } else { + let proba_1 = self.get_distribution_maximum(d1, state); + let proba_2 = self.get_distribution_maximum(d2, state); + if proba_1 < proba_2 { + Some((d2, proba_2)) + } else { + Some((d1, proba_1)) + } + } + } + } + /// Returns the number of constrained parents of the clause pub fn get_clause_number_parents(&self, clause: ClauseIndex, state: &StateManager) -> usize { self.clauses[clause.0].parents.len(state) diff --git a/src/heuristics/branching_approximate.rs b/src/heuristics/branching_approximate.rs new file mode 100644 index 0000000..ab81e0a --- /dev/null +++ b/src/heuristics/branching_approximate.rs @@ -0,0 +1,53 @@ +//Schlandals +//Copyright (C) 2022-2023 A. Dubray +// +//This program is free software: you can redistribute it and/or modify +//it under the terms of the GNU Affero General Public License as published by +//the Free Software Foundation, either version 3 of the License, or +//(at your option) any later version. +// +//This program is distributed in the hope that it will be useful, +//but WITHOUT ANY WARRANTY; without even the implied warranty of +//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +//GNU Affero General Public License for more details. +// +//You should have received a copy of the GNU Affero General Public License +//along with this program. If not, see . + +use search_trail::StateManager; +use crate::core::components::{ComponentExtractor, ComponentIndex}; +use crate::core::graph::{DistributionIndex, Graph}; +use crate::heuristics::BranchingDecision; + + +#[derive(Default)] +pub struct MaxProbability; + +impl BranchingDecision for MaxProbability { + fn branch_on( + &mut self, + g: &Graph, + state: &StateManager, + component_extractor: &ComponentExtractor, + component: ComponentIndex, + ) -> Option { + let mut best_score = usize::MAX; + let mut best_distribution: Option = None; + let mut best_tie = 0.0; + for clause in component_extractor.component_iter(component) { + if g.is_clause_constrained(clause, state) && g.clause_has_probabilistic(clause, state) { + let score = g.get_clause_number_parents(clause, state); + let (d, proba) = g.get_clause_active_distribution_highest_value(clause, state).unwrap(); + if score < best_score || (score == best_score && proba > best_tie) { + best_score = score; + best_tie = proba; + best_distribution = Some(d); + } + } + } + best_distribution + } + + fn init(&mut self, _g: &Graph, _state: &StateManager) {} + +} \ No newline at end of file diff --git a/src/heuristics/branching.rs b/src/heuristics/branching_exact.rs similarity index 56% rename from src/heuristics/branching.rs rename to src/heuristics/branching_exact.rs index 3b5fb53..61cadff 100644 --- a/src/heuristics/branching.rs +++ b/src/heuristics/branching_exact.rs @@ -28,142 +28,8 @@ use search_trail::StateManager; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::{ClauseIndex, DistributionIndex, Graph}; -use nalgebra::DMatrix; +use crate::heuristics::BranchingDecision; -/// Trait that defined the methods that a branching decision structure must implement. -pub trait BranchingDecision { - /// Chooses one distribution from the component to branch on and returns it. If no distribution is present in - /// the component, returns None. - fn branch_on( - &mut self, - g: &Graph, - state: &StateManager, - component_extractor: &ComponentExtractor, - component: ComponentIndex, - ) -> Option; - - /// Initialize, if necessary, the data structures used by the branching heuristics - fn init(&mut self, g: &Graph, state: &StateManager); -} - -/// A Fiedler-based branching heuristics. The Fiedler vector of a graph is the eigenvector -/// associated with the second smallest eigenvalue of the Laplacian matrix of a graph. -/// This vector gives a value to each node of the graph depending on its position in it. -/// A node on the boundary of the graph has a large value (positive or negative), and a node in the -/// center will have a value close to 0. -/// -/// This heuristics computes the fiedler vector for the implication graph of the clauses. Since computing the -/// fiedler vector is computationnaly heavy, it is done only at the beginning, during the initialization. -/// Then, we the branching decision must be selected, the mean fiedler values of all the clauses in the component -/// is computed, and the clause with the value closest to the mean is selected. -/// This means that a clause in the "center" of a component is selected. -/// Then, it selects the first unfixed distribution from the clause. -#[derive(Default)] -pub struct Fiedler { - fiedler_values: Vec, -} - -impl BranchingDecision for Fiedler { - - fn init(&mut self, g: &Graph, state: &StateManager) { - let mut lp_idx: Vec = (0..g.number_clauses()).collect(); - let mut cur_idx = 0; - for clause in g.clause_iter() { - if g.is_clause_constrained(clause, state) { - lp_idx[clause.0] = cur_idx; - cur_idx += 1; - } - } - - if cur_idx == 0 { - return; - } - - // Computation of the laplacian matrix of the implication graph. This is a square matrix L in which we have - // - L[i,i] = the degree of clause i - // - L[i, j] = -1 if clause i and j are connected (i != j) - // - // We assume that i and j are linked if we have either i -> j or j -> i - let mut laplacian = DMatrix::from_element(cur_idx, cur_idx, 0.0); - - for clause in g.clause_iter() { - // We only consider constrained clauses, to avoid noise from unnecessary clauses. - if g.is_clause_constrained(clause, state) { - for parent in g.parents_clause_iter(clause, state) { - if g.is_clause_constrained(parent, state) { - laplacian[(lp_idx[clause.0], lp_idx[clause.0])] += 0.5; - laplacian[(lp_idx[parent.0], lp_idx[parent.0])] += 0.5; - laplacian[(lp_idx[clause.0], lp_idx[parent.0])] = -1.0; - } - } - for child in g.children_clause_iter(clause, state) { - if g.is_clause_constrained(child, state) { - laplacian[(lp_idx[clause.0], lp_idx[clause.0])] += 0.5; - laplacian[(lp_idx[child.0], lp_idx[child.0])] += 0.5; - laplacian[(lp_idx[clause.0], lp_idx[child.0])] = -1.0; - } - } - } - } - - // Computing the eigenvectors - let decomp = laplacian.hermitian_part().symmetric_eigen(); - let mut smallest = (f64::INFINITY, f64::INFINITY); - let mut indexes = (0, 0); - for i in 0..cur_idx { - let eigenvalue = decomp.eigenvalues[i]; - if eigenvalue < smallest.0 { - smallest.1 = smallest.0; - indexes.1 = indexes.0; - smallest.0 = eigenvalue; - indexes.0 = i; - } else if eigenvalue < smallest.1 { - smallest.1 = eigenvalue; - indexes.1 = i; - } - } - self.fiedler_values = (0..g.number_clauses()).map(|i| { - if g.is_clause_constrained(ClauseIndex(i), state) { - decomp.eigenvectors.row(lp_idx[i])[indexes.1] - } else { - 0.0 - } - }).collect::>(); - } - - fn branch_on(&mut self, g: &Graph, state: &StateManager, component_extractor: &ComponentExtractor, component: ComponentIndex) -> Option { - let mut best_clause: Option = None; - let mut sum_fiedler = 0.0; - let mut count = 1.0; - for clause in component_extractor.component_iter(component) { - if g.is_clause_constrained(clause, state) { - sum_fiedler += self.fiedler_values[clause.0]; - count += 1.0; - } - } - if count == 0.0 { - return None; - } - let mean_fiedler = sum_fiedler / count; - let mut best_score = f64::MAX; - for clause in component_extractor.component_iter(component) { - if g.is_clause_constrained(clause, state) && g.clause_has_probabilistic(clause, state) { - let score = (self.fiedler_values[clause.0] - mean_fiedler).abs(); - if score < best_score { - best_score = score; - best_clause = Some(clause); - } - } - } - match best_clause { - Some(clause) => { - debug_assert!(g.clause_has_probabilistic(clause, state)); - g.get_clause_active_distribution(clause, state) - }, - None => None - } - } -} /// This heuristic selects the clause with the minimum in degree. In case of tie, it selects the clause /// for which the less number of parents have been removed. @@ -289,7 +155,8 @@ mod test_heuristics { use crate::core::graph::{Graph, VariableIndex, DistributionIndex}; use crate::core::components::ComponentExtractor; - use crate::heuristics::branching::*; + use crate::core::components::ComponentIndex; + use crate::heuristics::branching_exact::*; use search_trail::StateManager; // Graph used for the tests: @@ -322,24 +189,6 @@ mod test_heuristics { g } - #[test] - fn test_fiedler() { - let mut state = StateManager::default(); - let g = get_graph(&mut state); - let extractor = ComponentExtractor::new(&g, &mut state); - let mut branching = Fiedler::default(); - branching.init(&g, &state); - let expected_fiedler = vec![-0.506285, -00.297142, -0.216295, -0.0460978, 0.394186, 0.671633]; - println!("{:?}", branching.fiedler_values); - for i in 0..expected_fiedler.len() { - assert!((expected_fiedler[i] - branching.fiedler_values[i]).abs() <= 0.00001); - } - let decision = branching.branch_on(&g, &state, &extractor, ComponentIndex(0)); - assert!(decision.is_some()); - assert_eq!(DistributionIndex(3), decision.unwrap()); - - } - #[test] fn test_min_in_degree() { let mut state = StateManager::default(); diff --git a/src/heuristics/mod.rs b/src/heuristics/mod.rs index dc9280f..7cbfc7d 100644 --- a/src/heuristics/mod.rs +++ b/src/heuristics/mod.rs @@ -14,4 +14,25 @@ //You should have received a copy of the GNU Affero General Public License //along with this program. If not, see . -pub mod branching; \ No newline at end of file +use search_trail::StateManager; +use crate::core::components::{ComponentExtractor, ComponentIndex}; +use crate::core::graph::{DistributionIndex, Graph}; + +/// Trait that defined the methods that a branching decision structure must implement. +pub trait BranchingDecision { + /// Chooses one distribution from the component to branch on and returns it. If no distribution is present in + /// the component, returns None. + fn branch_on( + &mut self, + g: &Graph, + state: &StateManager, + component_extractor: &ComponentExtractor, + component: ComponentIndex, + ) -> Option; + + /// Initialize, if necessary, the data structures used by the branching heuristics + fn init(&mut self, g: &Graph, state: &StateManager); +} + +pub mod branching_exact; +pub mod branching_approximate; \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 89e7c0c..9668549 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,7 +23,9 @@ use rug::Float; use clap::ValueEnum; use crate::core::components::ComponentExtractor; -use crate::heuristics::branching::*; +use crate::heuristics::BranchingDecision; +use crate::heuristics::branching_exact::*; +use crate::heuristics::branching_approximate::*; use crate::search::{ExactDefaultSolver, ExactQuietSolver, ApproximateDefaultSolver, ApproximateQuietSolver}; use crate::propagator::{SearchPropagator, CompiledPropagator, MixedPropagator}; use crate::compiler::exact::ExactDACCompiler; @@ -50,6 +52,8 @@ pub enum Branching { MinOutDegree, /// Maximum degree of a clause in the implication-graph MaxDegree, + /// Select the distribution with the non-assigned variable that has the highest probability + MaxProbability, } pub fn compile(input: PathBuf, branching: Branching, fdac: Option, dotfile: Option) -> Option { @@ -61,6 +65,7 @@ pub fn compile(input: PathBuf, branching: Branching, fdac: Option, dotf Branching::MinInDegree => Box::::default(), Branching::MinOutDegree => Box::::default(), Branching::MaxDegree => Box::::default(), + Branching::MaxProbability => Box::::default(), }; let mut compiler = ExactDACCompiler::new(graph, state, component_extractor, branching_heuristic.as_mut(), propagator); let mut res = compiler.compile(); @@ -108,6 +113,7 @@ pub fn approximate_search(input: PathBuf, branching: Branching, statistics: bool Branching::MinInDegree => Box::::default(), Branching::MinOutDegree => Box::::default(), Branching::MaxDegree => Box::::default(), + Branching::MaxProbability => Box::::default(), }; let mlimit = if let Some(m) = memory { m @@ -149,6 +155,7 @@ pub fn search(input: PathBuf, branching: Branching, statistics: bool, memory: Op Branching::MinInDegree => Box::::default(), Branching::MinOutDegree => Box::::default(), Branching::MaxDegree => Box::::default(), + Branching::MaxProbability => Box::::default(), }; let mlimit = if let Some(m) = memory { m diff --git a/src/search/approximate.rs b/src/search/approximate.rs index eef2d6e..2cead70 100644 --- a/src/search/approximate.rs +++ b/src/search/approximate.rs @@ -29,7 +29,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; -use crate::heuristics::branching::BranchingDecision; +use crate::heuristics::BranchingDecision; use crate::propagator::MixedPropagator; use crate::search::statistics::Statistics; use crate::common::*; diff --git a/src/search/sequential.rs b/src/search/sequential.rs index 1845bcf..a4302b1 100644 --- a/src/search/sequential.rs +++ b/src/search/sequential.rs @@ -29,7 +29,7 @@ use search_trail::{StateManager, SaveAndRestore}; use crate::core::components::{ComponentExtractor, ComponentIndex}; use crate::core::graph::*; -use crate::heuristics::branching::BranchingDecision; +use crate::heuristics::BranchingDecision; use crate::propagator::SearchPropagator; use crate::search::statistics::Statistics; use crate::common::*;