Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added general form of transitive closure #1257

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
features:
- |
Added a new function ``transitive_closure`` to rustworkx which returns the
transitive closure of a graph. The transitive closure of G = (V,E) is a graph
G+ = (V,E+) such that for all v, w in V there is an edge (v, w) in E+ if and
only if there is a path from v to w in G.
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ from .rustworkx import graph_tensor_product as graph_tensor_product
from .rustworkx import graph_token_swapper as graph_token_swapper
from .rustworkx import digraph_transitivity as digraph_transitivity
from .rustworkx import graph_transitivity as graph_transitivity
from .rustworkx import transitive_closure as transitive_closure
from .rustworkx import digraph_bfs_search as digraph_bfs_search
from .rustworkx import graph_bfs_search as graph_bfs_search
from .rustworkx import digraph_dfs_search as digraph_dfs_search
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(minimum_spanning_edges))?;
m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?;
m.add_wrapped(wrap_pyfunction!(graph_transitivity))?;
m.add_wrapped(wrap_pyfunction!(transitive_closure))?;
m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?;
m.add_wrapped(wrap_pyfunction!(graph_token_swapper))?;
m.add_wrapped(wrap_pyfunction!(graph_core_number))?;
Expand Down
65 changes: 65 additions & 0 deletions src/transitivity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,20 @@
use super::{digraph, graph};
use hashbrown::HashSet;

use crate::digraph::PyDiGraph;
use petgraph::algo::kosaraju_scc;
use petgraph::algo::DfsSpace;
use petgraph::graph::DiGraph;
use pyo3::prelude::*;

use petgraph::visit::EdgeRef;
use petgraph::visit::IntoEdgeReferences;
use petgraph::visit::NodeCount;
use petgraph::graph::NodeIndex;
use rayon::prelude::*;

use rustworkx_core::traversal::build_transitive_closure_dag;

fn _graph_triangles(graph: &graph::PyGraph, node: usize) -> (usize, usize) {
let mut triangles: usize = 0;

Expand Down Expand Up @@ -186,3 +195,59 @@ pub fn digraph_transitivity(graph: &digraph::PyDiGraph) -> f64 {
_ => triangles as f64 / triples as f64,
}
}

/// Returns the transitive closure of a graph
#[pyfunction]
#[pyo3(text_signature = "(graph, /")]
pub fn transitive_closure(py: Python, graph: &PyDiGraph) -> PyResult<PyDiGraph> {
let sccs = kosaraju_scc(&graph.graph);

let mut condensed_graph = DiGraph::new();
let mut scc_nodes = Vec::new();
let mut scc_map: Vec<NodeIndex> = vec![NodeIndex::end(); graph.node_count()];

for scc in &sccs {
let scc_node = condensed_graph.add_node(());
scc_nodes.push(scc_node);
for node in scc {
scc_map[node.index()] = scc_node;
}
}
for edge in graph.graph.edge_references() {
let (source, target) = (edge.source(), edge.target());

if scc_map[source.index()] != scc_map[target.index()] {
condensed_graph.add_edge(scc_map[source.index()], scc_map[target.index()], ());
}
}

let closure_graph_result = build_transitive_closure_dag(condensed_graph, None, || {});
let out_graph = closure_graph_result.unwrap();

let mut new_graph = graph.graph.clone();
new_graph.clear();

let mut result_map: Vec<NodeIndex> = vec![NodeIndex::end(); out_graph.node_count()];
for (_index, node) in out_graph.node_indices().enumerate() {
let result_node = new_graph.add_node(py.None());
result_map[node.index()] = result_node;
}
for edge in out_graph.edge_references() {
let (source, target) = (edge.source(), edge.target());
new_graph.add_edge(
result_map[source.index()],
result_map[target.index()],
py.None(),
);
}
let out = PyDiGraph {
graph: new_graph,
cycle_state: DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: py.None(),
};

Ok(out)
}
70 changes: 70 additions & 0 deletions tests/graph/test_transitive_closure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx


class TestTransitive(unittest.TestCase):
def test_transitive_closure(self):

graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(4)))
graph.add_edge(0, 1, ())
graph.add_edge(1, 2, ())
graph.add_edge(2, 0, ())
graph.add_edge(2, 3, ())

closure_graph = rustworkx.transitive_closure(graph)
self.expected_edges = [
(0, 1),
(0, 2),
(0, 3),
(1, 0),
(1, 2),
(1, 3),
(2, 0),
(2, 1),
(2, 3),
]

self.assertEqualEdgeList(self.expected_edges, closure_graph.edge_list())

def test_transitive_closure_single_node(self):
graph = rustworkx.PyDiGraph()
graph.add_node(())
closure_graph = rustworkx.transitive_closure(graph)
expected_edges = []
self.assertEqualEdgeList(expected_edges, closure_graph.edge_list())

def test_transitive_closure_no_edges(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(4)))
closure_graph = rustworkx.transitive_closure(graph)
expected_edges = []
self.assertEqualEdgeList(expected_edges, closure_graph.edge_list())

def test_transitive_closure_complete_graph(self):
graph = rustworkx.PyDiGraph()
graph.add_nodes_from(list(range(4)))
for i in range(4):
for j in range(4):
if i != j:
graph.add_edge(i, j, ())
closure_graph = rustworkx.transitive_closure(graph)
expected_edges = [(i, j) for i in range(4) for j in range(4) if i != j]
self.assertEqualEdgeList(expected_edges, closure_graph.edge_list())

def assertEqualEdgeList(self, expected, actual):
for edge in actual:
self.assertTrue(edge in expected)
Loading