Skip to content

Commit

Permalink
Wrote a load of code in lib.rs that generalises applying move functio…
Browse files Browse the repository at this point in the history
…ns to work with moves on the topoology and the rate matrix
  • Loading branch information
jhellewell14 committed Nov 29, 2024
1 parent add74da commit 49e0bb4
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 82 deletions.
8 changes: 8 additions & 0 deletions src/branchlength.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use crate::Topology;

impl Topology {
pub fn update_branchlength(mut self, index: usize) {


}
}
150 changes: 149 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod topology;
mod genetic_data;
mod moves;
mod state_data;
mod branchlength;

use rate_matrix::RateMatrix;
use state_data::create_dummy_statedata;
Expand All @@ -21,7 +22,9 @@ use crate::genetic_data::*;
use crate::moves::*;
use rand::Rng;
use crate::iterators::Handedness;
use crate::rate_matrix::update_matrix;
// use crate::rate_matrix::update_matrix;
use ndarray::s;
use std::collections::HashMap;

pub fn main() {
let args = cli_args();
Expand All @@ -45,6 +48,151 @@ pub fn main() {
let mge_mat = na::Matrix2::new(0.4, 0.6, 0.6, 0.4);
let mut st = create_dummy_statedata(1, &t, &mge_mat);

let nodes: Vec<usize> = t.postorder(t.get_root()).map(|n| n.get_id()).collect();
for i in nodes {
let old_len = t.nodes[i].get_branchlen();
t.nodes[i].set_branchlen(old_len + 1.0);
};

pub struct TreeState<R: RateMatrix>{
top: Topology,
mat: R,
ll: Option<f64>,
changed_nodes: Option<Vec<usize>>,
}

pub trait TreeMove<R: RateMatrix> {
fn generate(&self, ts: &TreeState<R>) -> TreeState<R>;
}

pub struct MatrixMove {}

impl<R: RateMatrix> TreeMove<R> for MatrixMove {
fn generate(&self, ts: &TreeState<R>) -> TreeState<R> {
let rm = ts.mat.matrix_move();
let changes: Vec<usize> = ts.top.postorder_notips(ts.top.get_root()).map(|n| n.get_id()).collect();
// This is not ideal
let new_top = Topology{
nodes: ts.top.nodes.clone(),
tree_vec: ts.top.tree_vec.clone(),
likelihood: ts.top.likelihood,
};

TreeState{
top: new_top,
mat: rm,
ll: ts.ll,
changed_nodes: Some(changes),
}
}
}

impl<R:RateMatrix> TreeMove<R> for ExactMove {
fn generate(&self, ts: &TreeState<R>) -> TreeState<R> {
let new_topology = Topology::from_vec(&self.target_vector);
let changes: Option<Vec<usize>> = ts.top.find_changes(&new_topology);
let mat = ts.mat;
TreeState{
top: new_topology,
mat: mat,
ll: ts.ll,
changed_nodes: changes,
}
}
}

impl<R: RateMatrix> TreeState<R> {

pub fn likelihood(&self, gen_data: &ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> f64 {
let root_likelihood = gen_data.slice(s![self.top.get_root().get_id(), .., .. ]);

root_likelihood
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT))
}


pub fn apply_move<T: TreeMove<R>>(mut self,
move_fn: T,
accept_fn: fn(&f64, &f64) -> bool,
gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>) -> TreeState<R> {

if self.ll.is_none() {
self.ll = Some(self.likelihood(gen_data));
}
let old_ll = self.ll.unwrap();

let rate_mat = self.mat.get_matrix();
let new_ts = move_fn.generate(&self);

// If move did nothing, return old TreeState
if new_ts.changed_nodes.is_none() {
return self
}

// Do minimal likelihood updates (and push new values into HashMap temporarily)
let nodes = new_ts.top.changes_iter(new_ts.changed_nodes.unwrap());
let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

for node in nodes {
// check if in HM
let lchild = node.get_lchild().unwrap();
let rchild = node.get_rchild().unwrap();
let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
(true, true) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(true, false) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = slice_data(rchild, &gen_data);
},
(false, true) => {
seql = slice_data(lchild, &gen_data);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(false, false) => {
seql = slice_data(lchild, &gen_data);
seqr = slice_data(rchild, &gen_data);
},
};

let node_ll = node_likelihood(seql, seqr,
&matrix_exp(&rate_mat, new_ts.top.nodes[lchild].get_branchlen()),
&matrix_exp(&rate_mat, new_ts.top.nodes[rchild].get_branchlen()));

temp_likelihoods.insert(node.get_id(), node_ll);
}

// Calculate whole new topology likelihood at root
let new_ll = temp_likelihoods
.get(&new_ts.top.get_root().get_id())
.unwrap()
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

// Likelihood decision rule
if accept_fn(&old_ll, &new_ll) {
// Drain hashmap into gen_data
for (i, ll_data) in temp_likelihoods.drain() {
gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
}
// Update Topology
self.top.nodes = new_ts.top.nodes;
self.top.tree_vec = new_ts.top.tree_vec;
self.mat = new_ts.mat;
self.ll = Some(new_ll);
};

self

}
}
// let mut pp = rate_matrix::GTR::default();
// println!("{:?}", pp.get_matrix());
// update_matrix(&mut t, always_accept, &mut gen_data, &mut pp);
Expand Down
162 changes: 81 additions & 81 deletions src/rate_matrix.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::topology::Topology;
use crate::{likelihood, slice_data, node_likelihood, matrix_exp, BF_DEFAULT, base_freq_logse};
use crate::{base_freq_logse, likelihood, matrix_exp, node_likelihood, slice_data, CandidateTopology, MoveFn, BF_DEFAULT};
use statrs::distribution::{Dirichlet};
use rand::distributions::{Distribution, Uniform};
use std::collections::HashMap;
Expand All @@ -14,7 +14,7 @@ pub trait RateMatrix: Copy {

fn get_params(&self) -> Vec<f64>;

fn matrix_move(&mut self) -> Self;
fn matrix_move(&self) -> Self;
}

#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -77,7 +77,7 @@ impl RateMatrix for GTR {
-(self.c * self.p0 + self.e * self.p1 + self.f * self.p2));
}

fn matrix_move(&mut self) -> Self {
fn matrix_move(&self) -> Self {
let mut d1 = Dirichlet::new_with_param(1.0, 6).unwrap();
let pars = d1.sample(&mut rand::thread_rng());

Expand Down Expand Up @@ -151,7 +151,7 @@ impl RateMatrix for JC69 {
);
}

fn matrix_move(&mut self) -> Self {
fn matrix_move(&self) -> Self {
let rng = rand::thread_rng();
let dist = Uniform::new(0.0, 1.0);
let params = vec![dist.sample(&mut rand::thread_rng())];
Expand Down Expand Up @@ -200,80 +200,80 @@ impl MGE {
}
}

pub fn update_matrix<T: RateMatrix>(topology: &mut Topology,
accept_fn: fn(&f64, &f64) -> bool,
gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>,
rate_matrix: &mut T) -> () {

// Get current likelihood, calculating if needed
if topology.likelihood.is_none() {
topology.likelihood = Some(likelihood(&topology, gen_data));
}
let old_ll = topology.likelihood.unwrap();
println!("old ll: {:?}", old_ll);
// Generate new matrix
let new_mat = rate_matrix.matrix_move();

// Iterator over internal nodes
let nodes = topology.postorder_notips(topology.get_root());
// HashMap for potentially temporary likelihood calculations
let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

// Update likelihood at internal nodes
for node in nodes {
// check if in HM
let lchild = node.get_lchild().unwrap();
let rchild = node.get_rchild().unwrap();
let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
(true, true) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(true, false) => {
seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
seqr = slice_data(rchild, &gen_data);
},
(false, true) => {
seql = slice_data(lchild, &gen_data);
seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
},
(false, false) => {
seql = slice_data(lchild, &gen_data);
seqr = slice_data(rchild, &gen_data);
},
};

let node_ll = node_likelihood(seql, seqr,
&matrix_exp(&new_mat.get_matrix(), topology.nodes[lchild].get_branchlen()),
&matrix_exp(&new_mat.get_matrix(), topology.nodes[rchild].get_branchlen()));

temp_likelihoods.insert(node.get_id(), node_ll);
}

// Calculate whole new topology likelihood at root
let new_ll = temp_likelihoods
.get(&topology.get_root().get_id())
.unwrap()
.rows()
.into_iter()
.fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

println!("{:?}", new_mat.get_matrix());
println!("new ll: {:?}", new_ll);

// Likelihood decision rule
if accept_fn(&old_ll, &new_ll) {
// Drain hashmap into gen_data
for (i, ll_data) in temp_likelihoods.drain() {
gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
}
// Update likelihood
topology.likelihood = Some(new_ll);
rate_matrix.update_params(new_mat.get_params());
rate_matrix.update_matrix();
}

}
// pub fn update_matrix<T: RateMatrix>(topology: &mut Topology,
// accept_fn: fn(&f64, &f64) -> bool,
// gen_data: &mut ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 3]>>,
// rate_matrix: &mut T) -> () {

// // Get current likelihood, calculating if needed
// if topology.likelihood.is_none() {
// topology.likelihood = Some(likelihood(&topology, gen_data));
// }
// let old_ll = topology.likelihood.unwrap();
// println!("old ll: {:?}", old_ll);
// // Generate new matrix
// let new_mat = rate_matrix.matrix_move();

// // Iterator over internal nodes
// let nodes = topology.postorder_notips(topology.get_root());
// // HashMap for potentially temporary likelihood calculations
// let mut temp_likelihoods: HashMap<usize, ndarray::ArrayBase<ndarray::OwnedRepr<f64>, ndarray::Dim<[usize; 2]>>> = HashMap::new();

// // Update likelihood at internal nodes
// for node in nodes {
// // check if in HM
// let lchild = node.get_lchild().unwrap();
// let rchild = node.get_rchild().unwrap();
// let seql: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;
// let seqr: ndarray::ArrayBase<ndarray::ViewRepr<&f64>, ndarray::Dim<[usize; 2]>>;

// match (temp_likelihoods.contains_key(&lchild), temp_likelihoods.contains_key(&rchild)) {
// (true, true) => {
// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
// },
// (true, false) => {
// seql = temp_likelihoods.get(&lchild).unwrap().slice(s![.., ..]);
// seqr = slice_data(rchild, &gen_data);
// },
// (false, true) => {
// seql = slice_data(lchild, &gen_data);
// seqr = temp_likelihoods.get(&rchild).unwrap().slice(s![.., ..]);
// },
// (false, false) => {
// seql = slice_data(lchild, &gen_data);
// seqr = slice_data(rchild, &gen_data);
// },
// };

// let node_ll = node_likelihood(seql, seqr,
// &matrix_exp(&new_mat.get_matrix(), topology.nodes[lchild].get_branchlen()),
// &matrix_exp(&new_mat.get_matrix(), topology.nodes[rchild].get_branchlen()));

// temp_likelihoods.insert(node.get_id(), node_ll);
// }

// // Calculate whole new topology likelihood at root
// let new_ll = temp_likelihoods
// .get(&topology.get_root().get_id())
// .unwrap()
// .rows()
// .into_iter()
// .fold(0.0, |acc, base | acc + base_freq_logse(base, BF_DEFAULT));

// println!("{:?}", new_mat.get_matrix());
// println!("new ll: {:?}", new_ll);

// // Likelihood decision rule
// if accept_fn(&old_ll, &new_ll) {
// // Drain hashmap into gen_data
// for (i, ll_data) in temp_likelihoods.drain() {
// gen_data.slice_mut(s![i, .., ..]).assign(&ll_data);
// }
// // Update likelihood
// topology.likelihood = Some(new_ll);
// rate_matrix.update_params(new_mat.get_params());
// rate_matrix.update_matrix();
// }

// }

0 comments on commit 49e0bb4

Please sign in to comment.