From 2eea0f39c5e5b76e5833e780d8f35b5075ba4379 Mon Sep 17 00:00:00 2001 From: Thomas Chaplin Date: Wed, 6 Mar 2024 08:33:11 +0000 Subject: [PATCH] Allow arbitrary (s, t) list for `l_homology` binding --- Cargo.lock | 2 +- Cargo.toml | 2 +- examples/simple.py | 6 ++++++ src/bindings.rs | 51 ++++++++++++++++++++++++++++++++-------------- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 90b924a..7307aa7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -124,7 +124,7 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "gramag" -version = "0.3.0" +version = "0.3.1" dependencies = [ "anyhow", "dashmap", diff --git a/Cargo.toml b/Cargo.toml index 64d9919..99defda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gramag" -version = "0.3.0" +version = "0.3.1" edition = "2021" [dependencies] diff --git a/examples/simple.py b/examples/simple.py index 6a62fc3..401d59e 100644 --- a/examples/simple.py +++ b/examples/simple.py @@ -61,6 +61,12 @@ print(ds.representatives) print("") +print("Fixed l, custom (s, t) list - computed in parallel!") +ds2 = mg.l_homology(1, representatives=True, node_pairs=[(s, 5) for s in range(6)]) +print(ds2.ranks) +print(ds2.representatives) +print("") + print("Errors:") try: mg.l_homology(12) diff --git a/src/bindings.rs b/src/bindings.rs index 2f3d9ab..e42863e 100644 --- a/src/bindings.rs +++ b/src/bindings.rs @@ -137,23 +137,44 @@ impl MagGraph { Ok(PyStlHomology(Arc::new(homology))) } - // TODO: New method - allow arbitrary (s, t) list - fn l_homology(&self, l: usize, representatives: Option) -> Result { + fn l_homology( + &self, + l: usize, + representatives: Option, + node_pairs: Option>, + ) -> Result { self.check_l(l)?; let representatives = representatives.unwrap_or(false); - let stl_homologies: Vec<_> = self - .digraph - .node_identifiers() - .flat_map(|s| self.digraph.node_identifiers().map(move |t| (s, t))) - .par_bridge() - .map(|node_pair| { - ( - (node_pair, l), - Arc::new(self.inner_compute_stl_homology(node_pair, l, representatives)), - ) - }) - .collect(); - Ok(PyDirectSum(DirectSum::new(stl_homologies.into_iter()))) + let compute_stl_homologies = |node_pairs: Box< + dyn Iterator, NodeIndex)> + Send, + >| { + node_pairs + .par_bridge() + .map(|node_pair| { + ( + (node_pair, l), + Arc::new(self.inner_compute_stl_homology(node_pair, l, representatives)), + ) + }) + .collect::>() + .into_iter() + }; + + let stl_homologies = if let Some(u32_node_pairs) = node_pairs { + compute_stl_homologies(Box::new( + u32_node_pairs + .into_iter() + .map(|(s, t)| (NodeIndex::from(s), NodeIndex::from(t))), + )) + } else { + compute_stl_homologies(Box::new( + self.digraph + .node_identifiers() + .flat_map(|s| self.digraph.node_identifiers().map(move |t| (s, t))), + )) + }; + + Ok(PyDirectSum(DirectSum::new(stl_homologies))) } }