Skip to content

Commit

Permalink
Updated trait bounds for node attributes, added skeleton for parallel…
Browse files Browse the repository at this point in the history
… feature flag, added tests for parallel feature flag
  • Loading branch information
sriram98v committed Nov 29, 2024
1 parent a36859b commit 2ad5540
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 118 deletions.
16 changes: 9 additions & 7 deletions src/node/simple_rnode.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use itertools::Itertools;
use num::{Float, Num, NumCast, Signed};
use num::{Float, Signed};
use std::{hash::Hash, marker::Sync, fmt::{Debug, Display}, str::FromStr, iter::Sum};

/// Trait bound alias for Edge Weight.
pub trait EdgeWeight: Num + Display + Debug + Clone + PartialOrd + NumCast + Sum + Copy + ToString + Sync + FromStr + Float + Signed{}
pub trait EdgeWeight: Display + Debug + Sum + FromStr + Float + Signed + Sync + Send{}

/// Trait bound alias for Node Weight.
pub trait NodeWeight: Num + Display + Debug + Clone + PartialOrd + NumCast + Sum + Copy + ToString + Sync + FromStr + Float + Signed{}
pub trait NodeWeight: Display + Debug + Sum + FromStr + Float + Signed + Sync + Send{}

/// Trait bound alias for Node Taxa.
pub trait NodeTaxa: Display + Debug + Clone + Sync + FromStr + Ord + Hash + Sync{}
pub trait NodeTaxa: Display + Debug + Clone + FromStr + Ord + Hash + Sync + Send{}
// /// Trait bound alias for Node ID.
// pub trait NodeID: Display + Debug + Hash + Ord + Eq + Copy + Sync{}

Expand Down Expand Up @@ -115,7 +117,7 @@ where
/// A trait describing the behaviour of a Node in a n-ary tree that carries node annotations
pub trait RootedMetaNode: RootedTreeNode {
/// Meta annotation of node
type Meta: Display + Debug + Eq + PartialEq + Clone + Ord + Hash;
type Meta: NodeTaxa;

/// Returns node annotation
fn get_taxa<'a>(&'a self) -> Option<&'a Self::Meta>;
Expand All @@ -127,7 +129,7 @@ pub trait RootedMetaNode: RootedTreeNode {
/// A trait describing the behaviour of a Node in a n-ary tree that has numeric edge annotations
pub trait RootedWeightedNode: RootedTreeNode {
/// Weight of edge leading into node
type Weight: Num + Clone + PartialOrd + NumCast + std::iter::Sum;
type Weight: EdgeWeight;

/// Returns weight of edge leading into node
fn get_weight(&self) -> Option<Self::Weight>;
Expand All @@ -149,7 +151,7 @@ pub trait RootedWeightedNode: RootedTreeNode {
/// A trait describing the behaviour of a Node in a n-ary tree with numeric node annotations
pub trait RootedZetaNode: RootedTreeNode {
/// Zeta annotation of a node
type Zeta: Num + Clone + PartialOrd + NumCast + std::iter::Sum;
type Zeta: NodeWeight;

/// Returns node annotation
fn get_zeta(&self) -> Option<Self::Zeta>;
Expand Down
167 changes: 74 additions & 93 deletions src/tree/distances.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use itertools::Itertools;
use num::{One, Float, NumCast, Signed};
use std::fmt::{Debug, Display};
use num::{One, Float, NumCast};
use std::fmt::Debug;
use vers_vecs::BitVec;

#[cfg(feature = "non_crypto_hash")]
Expand All @@ -9,7 +9,7 @@ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::collections::{HashMap, HashSet};

#[cfg(feature = "parallel")]
use rayon::{prelude::*, iter::ParallelBridge};
use rayon::prelude::*;

use crate::prelude::*;

Expand Down Expand Up @@ -253,66 +253,56 @@ pub trait CopheneticDistance:
PathFunction + RootedMetaTree + Clusters + Ancestors + ContractTree + Debug
where
<Self as RootedTree>::Node: RootedMetaNode + RootedZetaNode,
<<Self as RootedTree>::Node as RootedZetaNode>::Zeta: Signed
+ Clone
+ NumCast
+ std::iter::Sum
+ Debug
+ Display
+ Float
+ PartialOrd
+ Copy
+ Sync,
TreeNodeZeta<Self>: NodeWeight,
{
/// Returns zeta of leaf by taxa
fn get_zeta_taxa(
&self,
taxa: &TreeNodeMeta<Self>,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
) -> TreeNodeZeta<Self> {
self.get_zeta(self.get_taxa_node_id(taxa).unwrap()).unwrap()
}

/// Reurns the nth norm of an iterator composed of floating point values
fn compute_norm(
vector: impl IntoIterator<Item = <<Self as RootedTree>::Node as RootedZetaNode>::Zeta>,
vector: impl Iterator<Item = TreeNodeZeta<Self>>,
norm: u32,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
) -> TreeNodeZeta<Self> {
if norm == 1 {
return vector.into_iter().map(|x| x.clone()).sum();
return vector.map(|x| x.clone()).sum();
}
vector
.into_iter()
.map(|x| {
let mut out = <TreeNodeZeta<Self>>::one();
for _ in 0..norm{
out = out* x.clone();
}
out
})
.sum::<<<Self as RootedTree>::Node as RootedZetaNode>::Zeta>()
.sum::<TreeNodeZeta<Self>>()
.powf(
<<<Self as RootedTree>::Node as RootedZetaNode>::Zeta as NumCast>::from(norm)
<TreeNodeZeta<Self> as NumCast>::from(norm)
.unwrap()
.powi(-1),
)
}

#[cfg(feature = "parallel")]
/// Returns the vector norm for an iterator
fn compute_norm_par(
vector: impl IntoIterator<Item = <<Self as RootedTree>::Node as RootedZetaNode>::Zeta> + IntoParallelIterator<Item = <<Self as RootedTree>::Node as RootedZetaNode>::Zeta>,
vector: impl Iterator<Item = TreeNodeZeta<Self>>,
norm: u32,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
) -> TreeNodeZeta<Self> {
if norm == 1 {
return vector.into_iter().map(|x| x.clone()).sum();
return vector.map(|x| x.clone()).sum();
}
vector
.into_par_iter()
.map(|x| {
x.clone().powi(norm as i32)
})
.sum::<<<Self as RootedTree>::Node as RootedZetaNode>::Zeta>()
.sum::<TreeNodeZeta<Self>>()
.powf(
<<<Self as RootedTree>::Node as RootedZetaNode>::Zeta as NumCast>::from(norm)
<TreeNodeZeta<Self> as NumCast>::from(norm)
.unwrap()
.powi(-1),
)
Expand All @@ -323,16 +313,16 @@ where
&'a self,
tree: &'a Self,
norm: u32,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
) -> TreeNodeZeta<Self> {
if !self.is_all_zeta_set() || !tree.is_all_zeta_set() {
panic!("Zeta values not set");
}
let binding1 = self
.get_taxa_space()
.collect::<HashSet<&<<Self as RootedTree>::Node as RootedMetaNode>::Meta>>();
.collect::<HashSet<&TreeNodeMeta<Self>>>();
let binding2 = tree
.get_taxa_space()
.collect::<HashSet<&<<Self as RootedTree>::Node as RootedMetaNode>::Meta>>();
.collect::<HashSet<&TreeNodeMeta<Self>>>();
let taxa_set = binding1.intersection(&binding2).cloned();

self.cophen_dist_by_taxa(tree, norm, taxa_set)
Expand All @@ -344,19 +334,19 @@ where
&'a self,
tree: &'a Self,
norm: u32,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
) -> TreeNodeZeta<Self> {
if !self.is_all_zeta_set() || !tree.is_all_zeta_set() {
panic!("Zeta values not set");
}
let binding1 = self
.get_taxa_space()
.collect::<HashSet<&<<Self as RootedTree>::Node as RootedMetaNode>::Meta>>();
.collect::<HashSet<&TreeNodeMeta<Self>>>();
let binding2 = tree
.get_taxa_space()
.collect::<HashSet<&<<Self as RootedTree>::Node as RootedMetaNode>::Meta>>();
.collect::<HashSet<&TreeNodeMeta<Self>>>();
let taxa_set = binding1.intersection(&binding2).cloned().map(|x| x.clone()).collect_vec();

self.cophen_dist_by_taxa_par(tree, norm, taxa_set)
self.cophen_dist_by_taxa_par(tree, norm, taxa_set.iter())
}

#[cfg(feature = "parallel")]
Expand All @@ -365,39 +355,35 @@ where
&'a self,
tree: &'a Self,
norm: u32,
taxa_set: impl IntoIterator<Item = TreeNodeMeta<Self>> + Clone,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
taxa_set: impl Iterator<Item = &'a TreeNodeMeta<Self>> + Send,
) -> TreeNodeZeta<Self> {
// let taxa_set = taxa_set.collect_vec();
let cophen_vec = taxa_set.into_iter()
.par_bridge()
let cophen_vec = taxa_set
.combinations_with_replacement(2)
.into_iter()
.map(|x| match x[0] == x[1] {
true => vec![x[0].clone()],
false => x,
})
.map(|x| match x.len() {
1 => {
let zeta_1 = self.get_zeta_taxa(&x[0]);
let zeta_2 = tree.get_zeta_taxa(&x[0]);
(zeta_1 - zeta_2).abs()
}
_ => {
let self_ids = x
.iter()
.map(|a| self.get_taxa_node_id(a).unwrap())
.collect_vec();
let tree_ids = x
.iter()
.map(|a| tree.get_taxa_node_id(a).unwrap())
.collect_vec();
let t_lca_id = self.get_lca_id(self_ids.as_slice());
let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
let zeta_1 = self.get_zeta(t_lca_id).unwrap();
let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
(zeta_1 - zeta_2).abs()
}
})
.collect_vec();
true => {
let zeta_1 = self.get_zeta_taxa(x[0]);
let zeta_2 = tree.get_zeta_taxa(x[0]);
(zeta_1 - zeta_2).abs()
},
false => {
let self_ids = x
.iter()
.map(|a| self.get_taxa_node_id(a).unwrap())
.collect_vec();
let tree_ids = x
.iter()
.map(|a| tree.get_taxa_node_id(a).unwrap())
.collect_vec();
let t_lca_id = self.get_lca_id(self_ids.as_slice());
let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
let zeta_1 = self.get_zeta(t_lca_id).unwrap();
let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
(zeta_1 - zeta_2).abs()
},
});
// .collect::<Vec<TreeNodeZeta<Self>>>();

Self::compute_norm_par(cophen_vec, norm)
}
Expand All @@ -407,39 +393,34 @@ where
&'a self,
tree: &'a Self,
norm: u32,
taxa_set: impl Iterator<Item = &'a TreeNodeMeta<Self>> + Clone,
) -> <<Self as RootedTree>::Node as RootedZetaNode>::Zeta {
let taxa_set = taxa_set.collect_vec();
taxa_set: impl Iterator<Item = &'a TreeNodeMeta<Self>>,
) -> TreeNodeZeta<Self> {
// let taxa_set = taxa_set.collect_vec();
let cophen_vec = taxa_set
.iter()
// .iter()
.combinations_with_replacement(2)
.map(|x| match x[0] == x[1] {
true => vec![x[0]],
false => x,
})
.map(|x| match x.len() {
1 => {
let zeta_1 = self.get_zeta_taxa(x[0]);
let zeta_2 = tree.get_zeta_taxa(x[0]);
(zeta_1 - zeta_2).abs()
}
_ => {
let self_ids = x
.iter()
.map(|a| self.get_taxa_node_id(a).unwrap())
.collect_vec();
let tree_ids = x
.iter()
.map(|a| tree.get_taxa_node_id(a).unwrap())
.collect_vec();
let t_lca_id = self.get_lca_id(self_ids.as_slice());
let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
let zeta_1 = self.get_zeta(t_lca_id).unwrap();
let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
(zeta_1 - zeta_2).abs()
}
})
.collect_vec();
true => {
let zeta_1 = self.get_zeta_taxa(x[0]);
let zeta_2 = tree.get_zeta_taxa(x[0]);
(zeta_1 - zeta_2).abs()
},
false => {
let self_ids = x
.iter()
.map(|a| self.get_taxa_node_id(a).unwrap())
.collect_vec();
let tree_ids = x
.iter()
.map(|a| tree.get_taxa_node_id(a).unwrap())
.collect_vec();
let t_lca_id = self.get_lca_id(self_ids.as_slice());
let t_hat_lca_id = tree.get_lca_id(tree_ids.as_slice());
let zeta_1 = self.get_zeta(t_lca_id).unwrap();
let zeta_2 = tree.get_zeta(t_hat_lca_id).unwrap();
(zeta_1 - zeta_2).abs()
},
});

Self::compute_norm(cophen_vec, norm)
}
Expand Down
41 changes: 23 additions & 18 deletions tests/tree-tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,29 +290,34 @@ fn bipartitions() {
#[test]
#[cfg(feature = "parallel")]
fn compute_norm_parallel() {
for norm in 1..100{
for norm in 1..10{
let x = (1..1000).map(|x| x as f32).collect_vec();
let y = x.clone();
assert_eq!(PhyloTree::compute_norm(x, 1), PhyloTree::compute_norm_par(y, 1));
assert!((PhyloTree::compute_norm(x.into_iter(), norm)-PhyloTree::compute_norm_par(y.into_iter(), norm)).abs()<0.1);
}

let x = (1..3).combinations_with_replacement(2).collect_vec();
let y = (1..3).combinations_with_replacement(2).par_bridge().map(|x| x[0]+x[1]).collect::<Vec<_>>();

dbg!(x, y);
}

#[test]
#[cfg(feature = "parallel")]
fn cophenetic_dist_par() {
fn depth(tree: &PhyloTree, node_id: usize) -> f32 {
tree.depth(node_id) as f32
}
let t1_input_str: String = String::from("((A,B),C);");
let t2_input_str: String = String::from("(A,(B,C));");
let mut t1 = PhyloTree::from_newick(t1_input_str.as_bytes()).unwrap();
let mut t2 = PhyloTree::from_newick(t2_input_str.as_bytes()).unwrap();
// #[test]
// #[cfg(feature = "parallel")]
// fn cophenetic_dist_par() {
// fn depth(tree: &PhyloTree, node_id: usize) -> f32 {
// tree.depth(node_id) as f32
// }
// let t1_input_str: String = String::from("((A,B),C);");
// let t2_input_str: String = String::from("(A,(B,C));");
// let mut t1 = PhyloTree::from_newick(t1_input_str.as_bytes()).unwrap();
// let mut t2 = PhyloTree::from_newick(t2_input_str.as_bytes()).unwrap();

t1.precompute_constant_time_lca();
t2.precompute_constant_time_lca();
// t1.precompute_constant_time_lca();
// t2.precompute_constant_time_lca();

t1.set_zeta(depth).unwrap();
t2.set_zeta(depth).unwrap();
// t1.set_zeta(depth).unwrap();
// t2.set_zeta(depth).unwrap();

assert_eq!(t1.cophen_dist_par(&t2, 1), 4_f32);
}
// assert_eq!(t1.cophen_dist_par(&t2, 1), 4_f32);
// }

0 comments on commit 2ad5540

Please sign in to comment.