Skip to content

Commit

Permalink
[pyschlandals] get output of distribution nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexandreDubray committed Sep 29, 2023
1 parent 1ded010 commit 815c133
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 194 deletions.
11 changes: 11 additions & 0 deletions pyschlandals/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,17 @@ impl PyDac {
self.dac.get_distribution_domain_size(DistributionNodeIndex(distribution))
}

/// Returns the pair (circuit node, value index) of the output of the distribution at its given output-index
pub fn get_distribution_node_output_at(&self, distribution: usize, index: usize) -> (usize, usize) {
let x = self.dac.get_distribution_output_at(DistributionNodeIndex(distribution), index);
(x.0.0, x.1)
}

/// Returns the number of output of a distribution node
pub fn get_distribution_number_output(&self, distribution: usize) -> usize {
self.dac.get_distribution_number_output(DistributionNodeIndex(distribution))
}

/// Returns the probability, of the given distribution, at the given index
pub fn get_distribution_probability(&self, distribution: usize, probability_index: usize) -> f64 {
self.dac.get_distribution_probability_at(DistributionNodeIndex(distribution), probability_index)
Expand Down
13 changes: 12 additions & 1 deletion src/compiler/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub struct CircuitNode {
/// A distribution node, an input of the circuit. Each distribution node holds the distribution's parameter as well as the outputs.
/// For each output node, it also stores the value that must be sent to the output (as an index of the probability vector).
struct DistributionNode {

/// Probabilities of the distribution
probabilities: Vec<f64>,
/// Outputs of the node
Expand Down Expand Up @@ -157,7 +158,7 @@ impl Dac {

/// Adds `output` to the outputs of `node` and `node` to the inputs of `output`. Note that this
/// function uses the vectors in each node. They are transferred afterward in the `outputs` vector.
pub fn add_spnode_output(&mut self, node: CircuitNodeIndex, output: CircuitNodeIndex) {
pub fn add_circuit_node_output(&mut self, node: CircuitNodeIndex, output: CircuitNodeIndex) {
self.nodes[node.0].outputs.push(output);
self.nodes[node.0].number_outputs += 1;
self.nodes[output.0].inputs.insert(node);
Expand Down Expand Up @@ -509,6 +510,16 @@ impl Dac {
self.distribution_nodes[distribution.0].probabilities[index]
}

/// Returns the pair (circuit_node, index) for the output of the distribution at the given index
pub fn get_distribution_output_at(&self, distribution: DistributionNodeIndex, index: usize) -> (CircuitNodeIndex, usize) {
self.distribution_nodes[distribution.0].outputs[index]
}

/// Returns the number of output of a distribution node
pub fn get_distribution_number_output(&self, distribution: DistributionNodeIndex) -> usize {
self.distribution_nodes[distribution.0].outputs.len()
}

// --- SETTERS --- //

/// Set the probability of the distribution, at the given index, to the given value
Expand Down
44 changes: 21 additions & 23 deletions src/compiler/exact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use search_trail::{StateManager, SaveAndRestore};

use crate::core::components::{ComponentExtractor, ComponentIndex};
use crate::core::graph::*;
use crate::heuristics::branching::BranchingDecision;
use crate::heuristics::BranchingDecision;
use crate::propagator::CompiledPropagator;
use crate::common::*;
use crate::compiler::circuit::*;
Expand Down Expand Up @@ -77,53 +77,51 @@ where
}
}

fn expand_sum_node(&mut self, spn: &mut Dac, component: ComponentIndex, distribution: DistributionIndex) -> Option<CircuitNodeIndex> {
fn expand_sum_node(&mut self, dac: &mut Dac, component: ComponentIndex, distribution: DistributionIndex) -> Option<CircuitNodeIndex> {
let mut children: Vec<CircuitNodeIndex> = vec![];
for variable in self.graph.distribution_variable_iter(distribution) {
self.state.save_state();
match self.propagator.propagate_variable(variable, true, &mut self.graph, &mut self.state, component, &self.component_extractor) {
Err(_) => {

},
Err(_) => { },
Ok(_) => {
if let Some(child) = self.expand_prod_node(spn, component) {
if let Some(child) = self.expand_prod_node(dac, component) {
children.push(child);
}
}
}
self.state.restore_state();
}
if !children.is_empty() {
let node = spn.add_sum_node();
let node = dac.add_sum_node();
for child in children {
spn.add_spnode_output(child, node);
dac.add_circuit_node_output(child, node);
}
Some(node)
} else {
None
}
}

fn expand_prod_node(&mut self, spn: &mut Dac, component: ComponentIndex) -> Option<CircuitNodeIndex> {
fn expand_prod_node(&mut self, dac: &mut Dac, component: ComponentIndex) -> Option<CircuitNodeIndex> {
let mut prod_node: Option<CircuitNodeIndex> = if self.propagator.has_assignments() || self.propagator.has_unconstrained_distribution() {
let node = spn.add_prod_node();
let node = dac.add_prod_node();
for (distribution, variable, value) in self.propagator.assignments_iter() {
if value {
let value_id = variable.0 - self.graph.get_distribution_start(distribution).0;
spn.add_distribution_output(distribution, node, value_id);
dac.add_distribution_output(distribution, node, value_id);
}
}

for distribution in self.propagator.unconstrained_distributions_iter() {
if self.graph.distribution_number_false(distribution, &self.state) != 0 {
let sum_node = spn.add_sum_node();
let sum_node = dac.add_sum_node();
for variable in self.graph.distribution_variable_iter(distribution) {
if !self.graph.is_variable_fixed(variable, &self.state) {
let value_id = variable.0 - self.graph.get_distribution_start(distribution).0;
spn.add_distribution_output(distribution, sum_node, value_id);
dac.add_distribution_output(distribution, sum_node, value_id);
}
}
spn.add_spnode_output(sum_node, node);
dac.add_circuit_node_output(sum_node, node);
}
}
Some(node)
Expand All @@ -137,7 +135,7 @@ where
match self.cache.get(&bit_repr) {
None => {
if let Some(distribution) = self.branching_heuristic.branch_on(&self.graph, &self.state, &self.component_extractor, sub_component) {
if let Some(child) = self.expand_sum_node(spn, sub_component, distribution) {
if let Some(child) = self.expand_sum_node(dac, sub_component, distribution) {
sum_children.push(child);
self.cache.insert(bit_repr, Some(child));
} else {
Expand Down Expand Up @@ -166,11 +164,11 @@ where
}
}
if !sum_children.is_empty() && prod_node.is_none() {
prod_node = Some(spn.add_prod_node());
prod_node = Some(dac.add_prod_node());
}
if let Some(node) = prod_node {
for child in sum_children {
spn.add_spnode_output(child, node);
dac.add_circuit_node_output(child, node);
}
}
prod_node
Expand All @@ -185,14 +183,14 @@ where
Err(_) => None,
Ok(_) => {
self.branching_heuristic.init(&self.graph, &self.state);
let mut spn = Dac::new(&self.graph);
match self.expand_prod_node(&mut spn, ComponentIndex(0)) {
let mut dac = Dac::new(&self.graph);
match self.expand_prod_node(&mut dac, ComponentIndex(0)) {
None => None,
Some(_) => {
spn.remove_dead_ends();
spn.reduce();
spn.layerize();
Some(spn)
dac.remove_dead_ends();
dac.reduce();
dac.layerize();
Some(dac)
}
}
}
Expand Down
25 changes: 13 additions & 12 deletions src/core/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub struct ComponentExtractor {
/// vector
clauses: Vec<ClauseIndex>,
/// The vector mapping for each `ClauseIndex` its position in `clauses`
positions: Vec<usize>,
clause_positions: Vec<usize>,
/// Holds the components computed by the extractor during the search
components: Vec<Component>,
/// The index of the first component of the current node in the search tree
Expand Down Expand Up @@ -110,7 +110,7 @@ impl ComponentExtractor {
/// Creates a new component extractor for the implication graph `g`
pub fn new(g: &Graph, state: &mut StateManager) -> Self {
let nodes = (0..g.number_clauses()).map(ClauseIndex).collect();
let positions = (0..g.number_clauses()).collect();
let clause_positions = (0..g.number_clauses()).collect();
let distributions = (0..g.number_distributions()).map(DistributionIndex).collect();
let distribution_positions = (0..g.number_distributions()).collect();
let components = vec![Component {
Expand All @@ -122,7 +122,7 @@ impl ComponentExtractor {
}];
Self {
clauses: nodes,
positions,
clause_positions,
components,
base: state.manage_usize(0),
limit: state.manage_usize(1),
Expand All @@ -146,7 +146,7 @@ impl ComponentExtractor {
) -> bool {
// if the clause has already been visited, then its position in the component must
// be between [start..(start + size)].
let clause_pos = self.positions[clause.0];
let clause_pos = self.clause_positions[clause.0];
g.is_clause_constrained(clause, state) && !(comp_start <= clause_pos && clause_pos < (comp_start + *comp_size))
}

Expand All @@ -171,15 +171,15 @@ impl ComponentExtractor {
if self.is_node_visitable(g, clause, comp_start, comp_size, state) {
*hash ^= g.get_clause_random(clause);
// The clause is swap with the clause at position comp_sart + comp_size
let current_pos = self.positions[clause.0];
let current_pos = self.clause_positions[clause.0];
let new_pos = comp_start + *comp_size;
// Only move the nodes if it is not already in position
// Not sure if this optimization is worth in practice
if new_pos != current_pos {
let moved_node = self.clauses[new_pos];
self.clauses.as_mut_slice().swap(new_pos, current_pos);
self.positions[clause.0] = new_pos;
self.positions[moved_node.0] = current_pos;
self.clause_positions[clause.0] = new_pos;
self.clause_positions[moved_node.0] = current_pos;
}
*comp_size += 1;

Expand Down Expand Up @@ -329,6 +329,7 @@ impl ComponentExtractor {
let end = start + self.components[component.0].number_distribution;
self.distributions[start..end].iter().copied()
}

}

/// This structure is used to implement a simple component detector that always returns one
Expand Down Expand Up @@ -433,7 +434,7 @@ mod test_component_detection {
let mut state = StateManager::default();
let g = get_graph(&mut state);
let extractor = ComponentExtractor::new(&g, &mut state);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions);
check_component(&extractor, 0, 6, (0..6).map(ClauseIndex).collect::<Vec<ClauseIndex>>());
assert_eq!(0, state.get_usize(extractor.base));
assert_eq!(1, state.get_usize(extractor.limit));
Expand All @@ -450,7 +451,7 @@ mod test_component_detection {
extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator);

assert_eq!(2, extractor.number_components(&state));
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions);
check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::<Vec<ClauseIndex>>());
check_component(&extractor, 5, 6, vec![ClauseIndex(5)]);

Expand All @@ -473,7 +474,7 @@ mod test_component_detection {
extractor.detect_components(&mut g, &mut state, ComponentIndex(0), &mut propagator);

assert_eq!(2, extractor.number_components(&state));
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions);
check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::<Vec<ClauseIndex>>());
check_component(&extractor, 5, 6, vec![ClauseIndex(5)]);

Expand All @@ -487,13 +488,13 @@ mod test_component_detection {
state.restore_state();

assert_eq!(2, extractor.number_components(&state));
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions);
check_component(&extractor, 0, 4, (0..4).map(ClauseIndex).collect::<Vec<ClauseIndex>>());
check_component(&extractor, 5, 6, vec![ClauseIndex(5)]);

state.restore_state();

assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.positions);
assert_eq!(vec![0, 1, 2, 3, 4, 5], extractor.clause_positions);
check_component(&extractor, 0, 6, (0..6).map(ClauseIndex).collect::<Vec<ClauseIndex>>());
}
}
40 changes: 40 additions & 0 deletions src/core/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,46 @@ impl Graph {
}
}

fn get_distribution_maximum(&self, distribution: DistributionIndex, state: &StateManager) -> f64 {
let mut max = 0.0;
for v in self.distribution_variable_iter(distribution) {
if !self.is_variable_fixed(v, state) {
let proba = self.get_variable_weight(v).unwrap();
if max < proba {
max = proba;
}
}
}
max
}

/// Returns the distribution with the unassgined variable with the highest probability.
pub fn get_clause_active_distribution_highest_value(&self, clause: ClauseIndex, state: &StateManager) -> Option<(DistributionIndex, f64)> {
let number_probabilistic_watched = self.get_clause_body_bounds_probabilistic(clause, state);
if number_probabilistic_watched == 0 {
None
} else if number_probabilistic_watched == 1 {
let d = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[0]).unwrap();
let proba = self.get_distribution_maximum(d, state);
Some((d, proba))
} else {
let d1 = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[0]).unwrap();
let d2 = self.get_variable_distribution(self.clauses[clause.0].body_probabilistic[1]).unwrap();
if d1 == d2 {
let proba_1 = self.get_distribution_maximum(d1, state);
Some((d1, proba_1))
} else {
let proba_1 = self.get_distribution_maximum(d1, state);
let proba_2 = self.get_distribution_maximum(d2, state);
if proba_1 < proba_2 {
Some((d2, proba_2))
} else {
Some((d1, proba_1))
}
}
}
}

/// Returns the number of constrained parents of the clause
pub fn get_clause_number_parents(&self, clause: ClauseIndex, state: &StateManager) -> usize {
self.clauses[clause.0].parents.len(state)
Expand Down
53 changes: 53 additions & 0 deletions src/heuristics/branching_approximate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//Schlandals
//Copyright (C) 2022-2023 A. Dubray
//
//This program is free software: you can redistribute it and/or modify
//it under the terms of the GNU Affero General Public License as published by
//the Free Software Foundation, either version 3 of the License, or
//(at your option) any later version.
//
//This program is distributed in the hope that it will be useful,
//but WITHOUT ANY WARRANTY; without even the implied warranty of
//MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
//GNU Affero General Public License for more details.
//
//You should have received a copy of the GNU Affero General Public License
//along with this program. If not, see <http://www.gnu.org/licenses/>.

use search_trail::StateManager;
use crate::core::components::{ComponentExtractor, ComponentIndex};
use crate::core::graph::{DistributionIndex, Graph};
use crate::heuristics::BranchingDecision;


#[derive(Default)]
pub struct MaxProbability;

impl BranchingDecision for MaxProbability {
fn branch_on(
&mut self,
g: &Graph,
state: &StateManager,
component_extractor: &ComponentExtractor,
component: ComponentIndex,
) -> Option<DistributionIndex> {
let mut best_score = usize::MAX;
let mut best_distribution: Option<DistributionIndex> = None;
let mut best_tie = 0.0;
for clause in component_extractor.component_iter(component) {
if g.is_clause_constrained(clause, state) && g.clause_has_probabilistic(clause, state) {
let score = g.get_clause_number_parents(clause, state);
let (d, proba) = g.get_clause_active_distribution_highest_value(clause, state).unwrap();
if score < best_score || (score == best_score && proba > best_tie) {
best_score = score;
best_tie = proba;
best_distribution = Some(d);
}
}
}
best_distribution
}

fn init(&mut self, _g: &Graph, _state: &StateManager) {}

}
Loading

0 comments on commit 815c133

Please sign in to comment.