From 0893b5ca041156228f2d058007b1ac256cc5f840 Mon Sep 17 00:00:00 2001 From: sriram98v Date: Tue, 12 Nov 2024 10:53:43 -0600 Subject: [PATCH] removed bitvec crate and used vers_vec::BitVec; updated benchmarks, updated rfs --- benches/main.rs | 80 +++++++++++++++++++++++++-------------- src/iter/node_iter.rs | 75 ++++++++++++------------------------- src/tree.rs | 1 + src/tree/distances.rs | 87 ++++++++++++++++++++++++++++++------------- tests/tree-tests.rs | 9 +++-- 5 files changed, 143 insertions(+), 109 deletions(-) diff --git a/benches/main.rs b/benches/main.rs index 7c3a676..fe50d1e 100644 --- a/benches/main.rs +++ b/benches/main.rs @@ -3,50 +3,49 @@ use phylo::prelude::*; use phylo::tree::SimpleRootedTree; use rand::{seq::IteratorRandom, thread_rng}; -const NUM_TAXA: usize = 4000; const NORM: u32 = 1; fn main() { divan::main(); } -#[divan::bench] -fn benchmark_constant_time_lca(bencher: divan::Bencher) { - let mut tree = SimpleRootedTree::yule(NUM_TAXA); +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_constant_time_lca(bencher: divan::Bencher, taxa_size: usize) { + let mut tree = SimpleRootedTree::yule(taxa_size); tree.precompute_constant_time_lca(); bencher.bench(|| tree.get_lca_id(vec![10, 20].as_slice())); } -#[divan::bench] -fn benchmark_lca(bencher: divan::Bencher) { - let tree = SimpleRootedTree::yule(NUM_TAXA); +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_lca(bencher: divan::Bencher, taxa_size: usize) { + let tree = SimpleRootedTree::yule(taxa_size); bencher.bench(|| tree.get_lca_id(vec![10, 20].as_slice())); } -#[divan::bench] -fn benchmark_yule(bencher: divan::Bencher) { - bencher.bench(|| SimpleRootedTree::yule(NUM_TAXA)); +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_yule(bencher: divan::Bencher, taxa_size: usize) { + bencher.bench(|| SimpleRootedTree::yule(taxa_size)); } -#[divan::bench] -fn benchmark_precompute_rmq(bencher: divan::Bencher) { +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_precompute_rmq(bencher: divan::Bencher, taxa_size: usize) { bencher - .with_inputs(|| SimpleRootedTree::yule(NUM_TAXA)) + .with_inputs(|| SimpleRootedTree::yule(taxa_size)) .bench_refs(|tree| { tree.precompute_constant_time_lca(); }); } -#[divan::bench] -fn benchmark_cophen_dist_naive(bencher: divan::Bencher) { +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_cophen_dist_naive(bencher: divan::Bencher, taxa_size: usize) { bencher .with_inputs(|| { fn depth(tree: &SimpleRootedTree, node_id: usize) -> f32 { EulerWalk::get_node_depth(tree, node_id) as f32 } - let mut t1 = SimpleRootedTree::yule(NUM_TAXA); - let mut t2 = SimpleRootedTree::yule(NUM_TAXA); + let mut t1 = SimpleRootedTree::yule(taxa_size); + let mut t2 = SimpleRootedTree::yule(taxa_size); t1.precompute_constant_time_lca(); t2.precompute_constant_time_lca(); let _ = t1.set_zeta(depth); @@ -58,25 +57,50 @@ fn benchmark_cophen_dist_naive(bencher: divan::Bencher) { }); } -#[divan::bench] -fn benchmark_rfs(bencher: divan::Bencher) { +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_rfs(bencher: divan::Bencher, taxa_size: usize) { bencher .with_inputs(|| { - let t1 = SimpleRootedTree::yule(NUM_TAXA); - let t2 = SimpleRootedTree::yule(NUM_TAXA); - (t1, t2) + let t1 = SimpleRootedTree::yule(taxa_size); + let t2 = SimpleRootedTree::yule(taxa_size); + + (t1,t2) }) .bench_refs(|(t1, t2)| { - let _ = t1.rfs(t2); + let _ = t1.rfs(&t2); + }); +} + +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_bps(bencher: divan::Bencher, taxa_size: usize) { + bencher + .with_inputs(|| { + let t1 = SimpleRootedTree::yule(taxa_size); + t1 + }) + .bench_refs(|t1| { + let _ = t1.get_bipartitions_ids().map(|(c1,c2)| (c1.map(|x| t1.get_node_taxa(x).unwrap()).collect_vec(), c2.map(|x| t1.get_node_taxa(x).unwrap()).collect_vec())).collect_vec(); + }); +} + +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_postord_ids(bencher: divan::Bencher, taxa_size: usize) { + bencher + .with_inputs(|| { + let t1 = SimpleRootedTree::yule(taxa_size); + t1 + }) + .bench_refs(|t1| { + let _ = t1.postord_ids(t1.get_root_id()).collect_vec(); }); } -#[divan::bench] -fn benchmark_ca(bencher: divan::Bencher) { +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] +fn benchmark_ca(bencher: divan::Bencher, taxa_size: usize) { bencher .with_inputs(|| { - let t1 = SimpleRootedTree::yule(NUM_TAXA); - let t2 = SimpleRootedTree::yule(NUM_TAXA); + let t1 = SimpleRootedTree::yule(taxa_size); + let t2 = SimpleRootedTree::yule(taxa_size); (t1, t2) }) .bench_refs(|(t1, t2)| { @@ -84,7 +108,7 @@ fn benchmark_ca(bencher: divan::Bencher) { }); } -#[divan::bench(args = [100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, 340, 350, 360, 370, 380, 390, 400, 410, 420, 430, 440, 450, 460, 470, 480, 490, 500, 510, 520, 530, 540, 550, 560, 570, 580, 590, 600, 610, 620, 630, 640, 650, 660, 670, 680, 690, 700, 710, 720, 730, 740, 750, 760, 770, 780, 790, 800, 810, 820, 830, 840, 850, 860, 870, 880, 890, 900, 910, 920, 930, 940, 950, 960, 970, 980, 990, 1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070, 1080, 1090, 1100, 1110, 1120, 1130, 1140, 1150, 1160, 1170, 1180, 1190, 1200, 1210, 1220, 1230, 1240, 1250, 1260, 1270, 1280, 1290, 1300, 1310, 1320, 1330, 1340, 1350, 1360, 1370, 1380, 1390, 1400, 1410, 1420, 1430, 1440, 1450, 1460, 1470, 1480, 1490, 1500, 1510, 1520, 1530, 1540, 1550, 1560, 1570, 1580, 1590, 1600, 1610, 1620, 1630, 1640, 1650, 1660, 1670, 1680, 1690, 1700, 1710, 1720, 1730, 1740, 1750, 1760, 1770, 1780, 1790, 1800, 1810, 1820, 1830, 1840, 1850, 1860, 1870, 1880, 1890, 1900, 1910, 1920, 1930, 1940, 1950, 1960, 1970, 1980, 1990, 2000, 2010, 2020, 2030, 2040, 2050, 2060, 2070, 2080, 2090, 2100, 2110, 2120, 2130, 2140, 2150, 2160, 2170, 2180, 2190, 2200, 2210, 2220, 2230, 2240, 2250, 2260, 2270, 2280, 2290, 2300, 2310, 2320, 2330, 2340, 2350, 2360, 2370, 2380, 2390, 2400, 2410, 2420, 2430, 2440, 2450, 2460, 2470, 2480, 2490, 2500, 2510, 2520, 2530, 2540, 2550, 2560, 2570, 2580, 2590, 2600, 2610, 2620, 2630, 2640, 2650, 2660, 2670, 2680, 2690, 2700, 2710, 2720, 2730, 2740, 2750, 2760, 2770, 2780, 2790, 2800, 2810, 2820, 2830, 2840, 2850, 2860, 2870, 2880, 2890, 2900, 2910, 2920, 2930, 2940, 2950, 2960, 2970, 2980, 2990, 3000, 3010, 3020, 3030, 3040, 3050, 3060, 3070, 3080, 3090, 3100, 3110, 3120, 3130, 3140, 3150, 3160, 3170, 3180, 3190, 3200, 3210, 3220, 3230, 3240, 3250, 3260, 3270, 3280, 3290, 3300, 3310, 3320, 3330, 3340, 3350, 3360, 3370, 3380, 3390, 3400, 3410, 3420, 3430, 3440, 3450, 3460, 3470, 3480, 3490, 3500, 3510, 3520, 3530, 3540, 3550, 3560, 3570, 3580, 3590, 3600, 3610, 3620, 3630, 3640, 3650, 3660, 3670, 3680, 3690, 3700, 3710, 3720, 3730, 3740, 3750, 3760, 3770, 3780, 3790, 3800, 3810, 3820, 3830, 3840, 3850, 3860, 3870, 3880, 3890, 3900, 3910, 3920, 3930, 3940, 3950, 3960, 3970, 3980, 3990, 4000, 4010, 4020, 4030, 4040, 4050, 4060, 4070, 4080, 4090, 4100, 4110, 4120, 4130, 4140, 4150, 4160, 4170, 4180, 4190, 4200, 4210, 4220, 4230, 4240, 4250, 4260, 4270, 4280, 4290, 4300, 4310, 4320, 4330, 4340, 4350, 4360, 4370, 4380, 4390, 4400, 4410, 4420, 4430, 4440, 4450, 4460, 4470, 4480, 4490, 4500, 4510, 4520, 4530, 4540, 4550, 4560, 4570, 4580, 4590, 4600, 4610, 4620, 4630, 4640, 4650, 4660, 4670, 4680, 4690, 4700, 4710, 4720, 4730, 4740, 4750, 4760, 4770, 4780, 4790, 4800, 4810, 4820, 4830, 4840, 4850, 4860, 4870, 4880, 4890, 4900, 4910, 4920, 4930, 4940, 4950, 4960, 4970, 4980, 4990, 5000, 5010, 5020, 5030, 5040, 5050, 5060, 5070, 5080, 5090, 5100, 5110, 5120, 5130, 5140, 5150, 5160, 5170, 5180, 5190, 5200, 5210, 5220, 5230, 5240, 5250, 5260, 5270, 5280, 5290, 5300, 5310, 5320, 5330, 5340, 5350, 5360, 5370, 5380, 5390, 5400, 5410, 5420, 5430, 5440, 5450, 5460, 5470, 5480, 5490, 5500, 5510, 5520, 5530, 5540, 5550, 5560, 5570, 5580, 5590, 5600, 5610, 5620, 5630, 5640, 5650, 5660, 5670, 5680, 5690, 5700, 5710, 5720, 5730, 5740, 5750, 5760, 5770, 5780, 5790, 5800, 5810, 5820, 5830, 5840, 5850, 5860, 5870, 5880, 5890, 5900, 5910, 5920, 5930, 5940, 5950, 5960, 5970, 5980, 5990, 6000, 6010, 6020, 6030, 6040, 6050, 6060, 6070, 6080, 6090, 6100, 6110, 6120, 6130, 6140, 6150, 6160, 6170, 6180, 6190, 6200, 6210, 6220, 6230, 6240, 6250, 6260, 6270, 6280, 6290, 6300, 6310, 6320, 6330, 6340, 6350, 6360, 6370, 6380, 6390, 6400, 6410, 6420, 6430, 6440, 6450, 6460, 6470, 6480, 6490, 6500, 6510, 6520, 6530, 6540, 6550, 6560, 6570, 6580, 6590, 6600, 6610, 6620, 6630, 6640, 6650, 6660, 6670, 6680, 6690, 6700, 6710, 6720, 6730, 6740, 6750, 6760, 6770, 6780, 6790, 6800, 6810, 6820, 6830, 6840, 6850, 6860, 6870, 6880, 6890, 6900, 6910, 6920, 6930, 6940, 6950, 6960, 6970, 6980, 6990, 7000, 7010, 7020, 7030, 7040, 7050, 7060, 7070, 7080, 7090, 7100, 7110, 7120, 7130, 7140, 7150, 7160, 7170, 7180, 7190, 7200, 7210, 7220, 7230, 7240, 7250, 7260, 7270, 7280, 7290, 7300, 7310, 7320, 7330, 7340, 7350, 7360, 7370, 7380, 7390, 7400, 7410, 7420, 7430, 7440, 7450, 7460, 7470, 7480, 7490, 7500, 7510, 7520, 7530, 7540, 7550, 7560, 7570, 7580, 7590, 7600, 7610, 7620, 7630, 7640, 7650, 7660, 7670, 7680, 7690, 7700, 7710, 7720, 7730, 7740, 7750, 7760, 7770, 7780, 7790, 7800, 7810, 7820, 7830, 7840, 7850, 7860, 7870, 7880, 7890, 7900, 7910, 7920, 7930, 7940, 7950, 7960, 7970, 7980, 7990, 8000, 8010, 8020, 8030, 8040, 8050, 8060, 8070, 8080, 8090, 8100, 8110, 8120, 8130, 8140, 8150, 8160, 8170, 8180, 8190, 8200, 8210, 8220, 8230, 8240, 8250, 8260, 8270, 8280, 8290, 8300, 8310, 8320, 8330, 8340, 8350, 8360, 8370, 8380, 8390, 8400, 8410, 8420, 8430, 8440, 8450, 8460, 8470, 8480, 8490, 8500, 8510, 8520, 8530, 8540, 8550, 8560, 8570, 8580, 8590, 8600, 8610, 8620, 8630, 8640, 8650, 8660, 8670, 8680, 8690, 8700, 8710, 8720, 8730, 8740, 8750, 8760, 8770, 8780, 8790, 8800, 8810, 8820, 8830, 8840, 8850, 8860, 8870, 8880, 8890, 8900, 8910, 8920, 8930, 8940, 8950, 8960, 8970, 8980, 8990, 9000, 9010, 9020, 9030, 9040, 9050, 9060, 9070, 9080, 9090, 9100, 9110, 9120, 9130, 9140, 9150, 9160, 9170, 9180, 9190, 9200, 9210, 9220, 9230, 9240, 9250, 9260, 9270, 9280, 9290, 9300, 9310, 9320, 9330, 9340, 9350, 9360, 9370, 9380, 9390, 9400, 9410, 9420, 9430, 9440, 9450, 9460, 9470, 9480, 9490, 9500, 9510, 9520, 9530, 9540, 9550, 9560, 9570, 9580, 9590, 9600, 9610, 9620, 9630, 9640, 9650, 9660, 9670, 9680, 9690, 9700, 9710, 9720, 9730, 9740, 9750, 9760, 9770, 9780, 9790, 9800, 9810, 9820, 9830, 9840, 9850, 9860, 9870, 9880, 9890, 9900, 9910, 9920, 9930, 9940, 9950, 9960, 9970, 9980, 9990, 10000])] +#[divan::bench(args = [200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 2200, 2400, 2600, 2800, 3000, 3200, 3400, 3600, 3800, 4000, 4200, 4400, 4600, 4800, 5000, 5200, 5400, 5600, 5800, 6000, 6200, 6400, 6600, 6800, 7000, 7200, 7400, 7600, 7800, 8000, 8200, 8400, 8600, 8800, 9000, 9200, 9400, 9600, 9800, 10000])] fn benchmark_contract(bencher: divan::Bencher, taxa_size: usize) { bencher .with_inputs(|| { diff --git a/src/iter/node_iter.rs b/src/iter/node_iter.rs index e478cf7..6936684 100644 --- a/src/iter/node_iter.rs +++ b/src/iter/node_iter.rs @@ -1,12 +1,12 @@ #![allow(clippy::needless_lifetimes)] #[cfg(feature = "non_crypto_hash")] -use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use fxhash::FxHashMap as HashMap; #[cfg(not(feature = "non_crypto_hash"))] use std::collections::{HashMap, HashSet}; use std::{collections::VecDeque, ops::Index}; - +use vers_vecs::BitVec; use itertools::Itertools; use crate::{ @@ -510,74 +510,45 @@ pub trait Clusters: DFS + BFS + Sized { .enumerate() .map(|(idx, id)| (id, idx)) .collect(); - let leaf_ids_rev: HashMap> = - leaf_ids.iter().map(|(id, idx)| (*idx, *id)).collect(); - let mut bps: HashMap, Vec> = vec![].into_iter().collect(); - let mut skip_clusters: HashSet> = vec![].into_iter().collect(); - let mut keep_nodes: Vec> = vec![]; + let leaf_ids_rev: Vec> = + leaf_ids.iter().map(|(id,_)| *id).collect(); + let num_leaves = leaf_ids.len(); + let mut bps: HashMap, BitVec> = vec![].into_iter().collect(); for n_id in self.postord_ids(self.get_root_id()) { + let mut bp = BitVec::from_zeros(num_leaves); match self.is_leaf(n_id) { true => { - let bp: (HashSet>, HashSet>) = ( - vec![n_id].into_iter().collect(), - leaf_ids.keys().filter(|x| **x != n_id).copied().collect(), - ); - let mut binary_str = vec![false; leaf_ids.len()]; - for i in bp.0.iter() { - binary_str[*leaf_ids.get(i).unwrap()] = true; - } - if !skip_clusters.contains(&binary_str) - && !skip_clusters.contains(&binary_str.iter().map(|x| !x).collect_vec()) - { - skip_clusters.insert(binary_str.iter().map(|x| !x).collect_vec()); - keep_nodes.push(n_id); - } - bps.insert(n_id, binary_str); + bp.flip_bit(*leaf_ids.get(&n_id).unwrap()); + bps.insert(n_id, bp.clone()); } false => { - if n_id == self.get_root_id() { + if n_id==self.get_root_id(){ continue; } - let children_clusters = self + self .get_node_children_ids(n_id) .map(|x| bps.get(&x).unwrap()) - .collect_vec(); - let binary_str = (0..leaf_ids.len()) - .map(|x| { - let values = children_clusters - .iter() - .map(|y| y[x] as usize) - .collect_vec(); - values.iter().sum::() > 0 - }) - .collect_vec(); - if !skip_clusters.contains(&binary_str) - && !skip_clusters.contains(&binary_str.iter().map(|x| !x).collect_vec()) - { - skip_clusters.insert(binary_str.iter().map(|x| !x).collect_vec()); - keep_nodes.push(n_id); + .for_each(|x| {let _ = bp.apply_mask_or(x);}); + if !(self.get_node_parent_id(n_id)==Some(self.get_root_id())){ + bps.insert(n_id, bp); } - bps.insert(n_id, binary_str); } }; } - return keep_nodes.into_iter().map(move |n_id| { - let mut bp1 = vec![]; - let mut bp2 = vec![]; - bps.get(&n_id) - .unwrap() - .iter() - .enumerate() - .for_each(|(idx, x)| match x { + return bps.into_values().map(move |bit_bp| { + let mut bp1 = Vec::with_capacity(leaf_ids.len()); + let mut bp2 = Vec::with_capacity(leaf_ids.len()); + for idx in 0..bit_bp.len(){ + match bit_bp.is_bit_set(idx).unwrap() { true => { - bp1.push(leaf_ids_rev.get(&idx).unwrap().to_owned()); + bp1.push(leaf_ids_rev[idx].to_owned()); } false => { - bp2.push(leaf_ids_rev.get(&idx).unwrap().to_owned()); + bp2.push(leaf_ids_rev[idx].to_owned()); } - }); - + } + } (bp1.into_iter(), bp2.into_iter()) }); } diff --git a/src/tree.rs b/src/tree.rs index 36d720a..7cdfd75 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -25,6 +25,7 @@ mod simple_rooted_tree { use crate::node::{Node, NodeID}; use crate::prelude::*; use vers_vecs::BinaryRmq; + // use bitvec::prelude::*; #[cfg(feature = "non_crypto_hash")] use fxhash::FxHashMap as HashMap; diff --git a/src/tree/distances.rs b/src/tree/distances.rs index 4aa85aa..22b1c03 100644 --- a/src/tree/distances.rs +++ b/src/tree/distances.rs @@ -1,6 +1,7 @@ use itertools::Itertools; use num::{Float, NumCast, Signed}; use std::fmt::{Debug, Display}; +use vers_vecs::BitVec; #[cfg(feature = "non_crypto_hash")] use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; @@ -80,6 +81,7 @@ where { /// Returns Robinson Foulds distance between tree and self. fn rfs(&self, tree: &Self) -> usize { + let mut dist = 0; let mut all_taxa: HashSet<&TreeNodeMeta> = self.get_taxa_space().collect(); all_taxa.extend(tree.get_taxa_space()); let num_taxa = all_taxa.len(); @@ -88,40 +90,73 @@ where .enumerate() .map(|x| (x.1, x.0)) .collect(); + let mut self_bps: HashMap, BitVec> = vec![].into_iter().collect(); + let mut self_out_bps: HashSet = vec![].into_iter().collect(); + for n_id in self.postord_ids(self.get_root_id()) { + let mut bp = BitVec::from_zeros(num_taxa); + match self.is_leaf(n_id) { + true => { + let leaf_meta = self.get_node_taxa(n_id).unwrap(); + bp.flip_bit(*all_taxa_map.get(leaf_meta).unwrap()); + self_bps.insert(n_id, bp.clone()); + self_out_bps.insert(bp); + } + false => { + if n_id==self.get_root_id(){ + continue; + } + self.get_node_children_ids(n_id) + .map(|x| self_bps.get(&x).unwrap()) + .for_each(|x| {let _ = bp.apply_mask_or(x);}); + if !(self.get_node_parent_id(n_id)==Some(self.get_root_id())){ + self_out_bps.insert(bp.clone()); + } + self_bps.insert(n_id, bp); + } + }; + } - let self_bps: HashSet> = self - .get_clusters_ids() - .map(|(_, cluster)| { - let mut bit_str = vec![false; num_taxa]; - cluster - .map(|x| self.get_node_taxa(x).unwrap()) - .for_each(|x| bit_str[*all_taxa_map.get(x).unwrap()] = true); - bit_str - }) - .collect(); - let tree_bps: HashSet> = tree - .get_clusters_ids() - .map(|(_, cluster)| { - let mut bit_str = vec![false; num_taxa]; - cluster - .map(|x| tree.get_node_taxa(x).unwrap()) - .for_each(|x| bit_str[*all_taxa_map.get(x).unwrap()] = true); - bit_str - }) - .collect(); + let mut tree_bps: HashMap, BitVec> = vec![].into_iter().collect(); + let mut tree_out_bps: HashSet = vec![].into_iter().collect(); + for n_id in tree.postord_ids(tree.get_root_id()) { + let mut bp = BitVec::from_zeros(num_taxa); + match tree.is_leaf(n_id) { + true => { + let leaf_meta = tree.get_node_taxa(n_id).unwrap(); + bp.flip_bit(*all_taxa_map.get(leaf_meta).unwrap()); + tree_bps.insert(n_id, bp.clone()); + tree_out_bps.insert(bp.clone()); + } + false => { + if n_id==tree.get_root_id(){ + continue; + } + tree.get_node_children_ids(n_id) + .map(|x| tree_bps.get(&x).unwrap()) + .for_each(|x| {let _ = bp.apply_mask_or(x);}); + if !(tree.get_node_parent_id(n_id)==Some(tree.get_root_id())){ + tree_out_bps.insert(bp.clone()); + } + tree_bps.insert(n_id, bp.clone()); + } + }; + } - let mut dist = 0; - for i in self_bps.iter() { - if !tree_bps.contains(i) && !tree_bps.contains(&i.iter().map(|x| !x).collect_vec()) { + for i in self_out_bps.iter() { + let mut i_rev = BitVec::from_ones(num_taxa); + let _ = i_rev.apply_mask_xor(i); + if !tree_out_bps.contains(i) && !tree_out_bps.contains(&i_rev) { dist += 1; } } - for i in tree_bps.iter() { - if !self_bps.contains(i) && !self_bps.contains(&i.iter().map(|x| !x).collect_vec()) { + for i in tree_out_bps.iter() { + let mut i_rev = BitVec::from_ones(num_taxa); + let _ = i_rev.apply_mask_xor(i); + if !self_out_bps.contains(i) && !self_out_bps.contains(&i_rev) { dist += 1; } } - + dist / 2 } } diff --git a/tests/tree-tests.rs b/tests/tree-tests.rs index 72c6af0..271e751 100644 --- a/tests/tree-tests.rs +++ b/tests/tree-tests.rs @@ -1,4 +1,7 @@ -use std::collections::HashSet; +#[cfg(feature = "non_crypto_hash")] +use fxhash::FxHashSet as HashSet; +#[cfg(not(feature = "non_crypto_hash"))] +use std::collections::{HashMap, HashSet}; use itertools::Itertools; use phylo::node::Node; @@ -70,7 +73,7 @@ fn read_small_tree() { dbg!(format!("{}", &tree.to_newick())); assert_eq!( &tree.get_taxa_space().collect::>(), - &HashSet::from([&"A".to_string(), &"B".to_string(), &"C".to_string()]) + &vec![&"A".to_string(), &"B".to_string(), &"C".to_string()].into_iter().collect() ); let input_str = String::from("((A:1e-3,B:2e-3),C:6e-3);"); let tree = SimpleRootedTree::from_newick(input_str.as_bytes()).unwrap(); @@ -266,7 +269,7 @@ fn bipartitions() { ) }) .collect_vec(); - + let input_str: String = String::from("(A, (B, (C, (D, (E, (F, (G, H)))))));"); let t1 = SimpleRootedTree::from_newick(input_str.as_bytes()).unwrap(); let _bps = t1