Skip to content

Commit

Permalink
Added quadratic time update function that doesn't require leaf permut…
Browse files Browse the repository at this point in the history
…ations. Tidied up tests to remove ones that were not working. Added a test to check that quadratic update works the same as linear update.
  • Loading branch information
jhellewell14 committed Mar 12, 2024
1 parent 8cfe718 commit b9ff8ed
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 112 deletions.
41 changes: 30 additions & 11 deletions src/dspsa.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Sub;

use rand::Rng;
use crate::Tree;

Expand All @@ -6,7 +8,7 @@ pub fn phi(v: &[f64]) -> Vec<f64> {
if i == 0 || value.lt(&0.0) {
0.0
} else if value.gt(&(i as f64)) {
(i as f64) - 0.001
(i as f64) - 0.000001
} else {
*value
}
Expand All @@ -31,7 +33,7 @@ pub fn peturbation_vec(n: usize) -> Vec<f64> {
}

pub fn theta_change(pivec: &Vec<f64>, delta: &Vec<f64>, plus: bool) -> Vec<usize> {

let zip = pivec.iter().zip(delta.iter());

match plus {
Expand All @@ -50,17 +52,23 @@ pub fn theta_change(pivec: &Vec<f64>, delta: &Vec<f64>, plus: bool) -> Vec<usize

impl Tree {
pub fn optimise(&mut self, q: &na::Matrix4<f64>, iterations: usize) {

// Update likelihood if not done already
if self.get_tree_likelihood().eq(&0.0) {
self.update_likelihood(&q);
}

// Convert tree vector to Vec<f64>
let mut theta: Vec<f64> = self.tree_vec.iter().map(|x| *x as f64).collect();
println!("Current tree vector is: {:?}", self.tree_vec);
println!("Current likelihood is: {}", - self.get_tree_likelihood());
println!("Current likelihood is: {}", self.get_tree_likelihood());
let n: usize = theta.len();

// Tuning parameters for optimisation, will
// eventually have defaults or be passed in
let a: f64 = 2.0;
let cap_a: f64 = 2.0;
let alpha: f64 = 0.75;
let a: f64 = 1.5;
let cap_a: f64 = 10000.0;
let alpha: f64 = 0.51;

// Pre-allocate vectors
let mut delta: Vec<f64> = Vec::with_capacity(n);
Expand All @@ -72,46 +80,57 @@ impl Tree {
// Optimisation loop
for k in 0..=iterations {
println!("Optimisation step {} out of {}", k, iterations);
println!("Tree likelihood: {}", self.get_tree_likelihood());
// Generate peturbation vector
delta = peturbation_vec(n);
// println!("Peturbation vector: {:?}", delta);

// Generate pi vector
pivec = piv(&theta);
// println!("Pi vector: {:?}", pivec);

// Calculate theta+ and theta-,
// New tree vectors based on peturbation
thetaplus = theta_change(&pivec, &delta, true);
thetaminus = theta_change(&pivec, &delta, false);
// println!("theta+: {:?}", thetaplus);
// println!("theta-: {:?}", thetaminus);

// Update tree and calculate likelihoods
self.update_tree(Some(thetaplus), false);
self.update_likelihood(&q);
let lplus: f64 = - self.get_tree_likelihood();
let lplus: f64 = -self.get_tree_likelihood();

self.update_tree(Some(thetaminus), false);
self.update_likelihood(&q);
let lminus: f64 = - self.get_tree_likelihood();
let lminus: f64 = -self.get_tree_likelihood();

// Update theta based on likelihoods of theta+/-
let ldiff = lplus - lminus;

// let ghat: Vec<f64> = delta.iter()
// .map(|el| if !el.eq(&0.0) {el * ldiff} else {0.0}).collect();
println!("ll+ is {} and ll- is {}, ldiff is {}", lplus, lminus, ldiff);

ghat = delta.iter().map(|delta| ldiff * (1.0 / delta)).collect();
ghat[0] = 0.0;

// println!("ghat is {:?}", ghat);

let ak: f64 = a / (1.0 + cap_a + k as f64).powf(alpha);

println!("ak is {}", ak);

// Set new theta
theta = theta.iter().zip(ghat.iter())
.map(|(theta, g)| *theta - ak * g).collect();

// println!("New theta is: {:?}", theta);
}

// Update final tree after finishing optimisation
let new_tree_vec: Vec<usize> = phi(&theta).iter().map(|x| x.round() as usize).collect();
println!("New tree vector is: {:?}", new_tree_vec);
self.update_tree(Some(new_tree_vec), false);
self.update_likelihood(&q);
println!("New tree likelihood is {}", - self.get_tree_likelihood());
println!("New tree likelihood is {}", self.get_tree_likelihood());
}
}
12 changes: 6 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ pub fn main() {
1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0, -1.0,
);

// let mut tr = phylo2vec_quad(vec![0, 0, 0]);
let mut tr = phylo2vec_quad(random_tree(27));

let end = Instant::now();
// let filename = "listeria0.aln";
tr.add_genetic_data(&args.alignment);

tr.update_likelihood_postorder(&q);

// println!("{:?}", tr.mutation_lists);
// println!("{}", tr.get_tree_likelihood());
// println!("{:?}", tr.newick());
// println!("{:?}", tr.tree_vec);
println!("{}", tr.get_tree_likelihood());
println!("{:?}", tr.newick());
println!("{:?}", tr.tree_vec);

if !args.no_optimise {
tr.optimise(&q, 10);
// tr.optimise(&q, 100);
}

// let end = Instant::now();
Expand Down
18 changes: 0 additions & 18 deletions src/likelihoods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,6 @@ impl Tree {
}

impl Mutation {
pub fn prod(self, r: Mutation) -> Mutation {
Mutation(
self.0 * r.0,
self.1 * r.1,
self.2 * r.2,
self.3 * r.3,
)
}

pub fn sum(self, r: Mutation) -> Mutation {
Mutation(
Expand All @@ -105,12 +97,6 @@ impl Mutation {
)
}

pub fn child_likelihood(self, prob_matrix: &na::Matrix4<f64>) -> Mutation {
let x = prob_matrix * na::Vector4::new(self.0, self.1, self.2, self.3);

Mutation(x[0], x[1], x[2], x[3])
}

pub fn child_log_likelihood(self, prob_matrix: &na::Matrix4<f64>) -> Self {
let lnx = vec![self.0, self.1, self.2, self.3];
let mut x: Vec<f64> = Vec::new();
Expand All @@ -123,10 +109,6 @@ impl Mutation {
}

pub fn logse(x: Vec<f64>) -> f64 {
// if x.iter().all(|el| el.eq(&NEG_INFINITY)) {
// NEG_INFINITY
// } else {
let xstar = x.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
xstar + x.iter().fold(0.0,|acc, el| acc + f64::exp(el - xstar)).ln()
// }
}
31 changes: 31 additions & 0 deletions src/phylo2vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,35 @@ impl Tree {
self.add(M[[i, 1]], Some(M[[i, 2]]));
}
}

pub fn update_quad(&mut self, new_vec: Vec<usize>) {

// if !self.changes.is_empty() {
// panic!("There are already changes that need updating");
// }

let new_tree: Tree = phylo2vec_quad(new_vec);
let k: usize = new_tree.nodes.len();
let mut old_parent: Option<usize>;
let mut new_parent: Option<usize>;

for i in (0..k).rev() {
old_parent = self.get_node(i).unwrap().parent;
new_parent = new_tree.get_node(i).unwrap().parent;

if old_parent.ne(&new_parent) {
let d = new_tree.get_node(i).unwrap().depth;

match self.changes.get(&d) {
None => {self.changes.insert(d, vec![i]);},
Some(_) => {self.changes.get_mut(&d).unwrap().push(i);}
}
}
}

self.tree_vec = new_tree.tree_vec;
self.nodes = new_tree.nodes;

}

}
103 changes: 26 additions & 77 deletions src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ mod tests {
use crate::phylo2vec::phylo2vec_lin;
use crate::phylo2vec::phylo2vec_quad;
use crate::tree::Tree;
// use crate::import::str2tree;
// use crate::gen_list::Entry;
// use crate::gen_list::MutationType;

#[test]
fn treemake_quad() {
Expand Down Expand Up @@ -131,58 +128,31 @@ mod tests {

}

// #[test]
// fn likelihood_multiplication_machinery() {
// let muts = Mutation(0.15, 0.5, 0.25, 0.1);

// let q: na::Matrix4<f64> = na::Matrix4::new(
// -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0,
// );

// let time = 0.75;

// let p = na::Matrix::exp(&(q * time));

// assert_eq!(p[(0, 0)], 0.6082994225745668);
// assert_eq!(p[(1, 2)], 0.5029001980127024);
// assert_eq!(p[(2, 1)], 0.5029001980127025);
// assert_eq!(p[(3, 3)], 0.6082994225745667);

// let ll = muts.child_likelihood(&p);

// assert_eq!(ll.1, 0.5187100816969821);
// assert_eq!(ll.3, 0.5292500041531686);

// // Check matrix multiplication works as expected
// assert_eq!(
// muts.0 * p[(0, 0)] + muts.1 * p[(0, 1)] + muts.2 * p[(0, 2)] + muts.3 * p[(0, 3)],
// ll.0
// );
// assert_eq!(
// muts.0 * p[(1, 0)] + muts.1 * p[(1, 1)] + muts.2 * p[(1, 2)] + muts.3 * p[(1, 3)],
// ll.1
// );
// assert_eq!(
// muts.0 * p[(2, 0)] + muts.1 * p[(2, 1)] + muts.2 * p[(2, 2)] + muts.3 * p[(2, 3)],
// ll.2
// );
// assert_eq!(
// muts.0 * p[(3, 0)] + muts.1 * p[(3, 1)] + muts.2 * p[(3, 2)] + muts.3 * p[(3, 3)],
// ll.3
// );

// // Check outcome of multiplying likelihoods from two child nodes
// let muts2 = Mutation(0.3, 0.1, 0.3, 0.1);
// let ll2 = muts2.child_likelihood(&p);

// let outcome = ll.prod(ll2);

// assert_eq!(outcome.0, ll.0 * ll2.0);
// assert_eq!(outcome.1, ll.1 * ll2.1);
// assert_eq!(outcome.2, ll.2 * ll2.2);
// assert_eq!(outcome.3, ll.3 * ll2.3);
// }
#[test]
fn update_tree_quad_check() {
let mut tree_q = phylo2vec_quad(vec![0, 1, 0]);
let mut tree_l = phylo2vec_lin(vec![0, 0, 0], false);

let vecs: Vec<Vec<usize>> = vec![vec![0, 0, 0], vec![0, 1, 0], vec![0, 1, 2], vec![0, 1, 1]];

for vec in vecs {
let v = vec.clone();
tree_q = phylo2vec_quad(v);
tree_l.update_quad(vec);

for i in 0..=6 {
assert_eq!(
tree_l.get_node(i).unwrap().parent,
tree_q.get_node(i).unwrap().parent);
assert_eq!(
tree_l.get_node(i).unwrap().index,
tree_q.get_node(i).unwrap().index
);
}
}

}

#[test]
fn likelihood_internal_consistency() {
let q: na::Matrix4<f64> = na::Matrix4::new(
Expand Down Expand Up @@ -228,38 +198,17 @@ mod tests {

let old_likelihood = tr.get_tree_likelihood();

tr.update_tree(Some(vec![0, 0, 1]), false);
tr.update_quad(vec![0, 0, 1]);
tr.update_likelihood(&q);

tr.update_tree(Some(vec![0, 0, 0]), false);
tr.update_quad(vec![0, 0, 0]);
tr.update_likelihood(&q);

let new_likelihood = tr.get_tree_likelihood();

assert_eq!(old_likelihood, new_likelihood);
}

// #[test]
// fn likelihood_value_correct () {
// let q: na::Matrix4<f64> = na::Matrix4::new(
// -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0, 1.0, 1.0, 1.0, 1.0, -3.0,
// );

// let mut tr = phylo2vec_quad(vec![0; 1]);

// let genetic_data = vec![vec![Mutation(f64::ln(0.0), 0.0, f64::ln(0.0), f64::ln(0.0))],
// vec![Mutation(f64::ln(0.0), f64::ln(0.0), f64::ln(0.0), 0.0)],
// vec![]]; // This is the likelihood at the only internal (root) node, it can't be empty but will be overwritten

// tr.mutation_lists = genetic_data;

// tr.update_likelihood_postorder(&q);

// let likelihood = tr.get_tree_likelihood();

// assert_eq!(-2.773571667625644, likelihood);
// }

#[test]
fn newick_test () {
let mut tr = phylo2vec_quad(vec![0, 0, 0]);
Expand Down

0 comments on commit b9ff8ed

Please sign in to comment.