From bdb3990fdd613e9fb982839ccf92d8ad67686c6d Mon Sep 17 00:00:00 2001 From: Niema Moshiri Date: Fri, 3 May 2024 09:49:49 -0700 Subject: [PATCH] Fixed bug in mutation rate Truncated-Normal model --- global.json | 2 +- plugins/common.py | 17 +++++++++++++++++ plugins/mutation_rates/common_treeswift.py | 10 ++++------ plugins/sample_times/time_windows.py | 5 ++--- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/global.json b/global.json index f028edc..b7a1681 100644 --- a/global.json +++ b/global.json @@ -1,5 +1,5 @@ { - "VERSION": "1.0.1", + "VERSION": "1.0.2", "CONFIG_KEYS": ["Contact Network", "Transmission Network", "Sample Times", "Viral Phylogeny (Transmissions)", "Viral Phylogeny (Seeds)", "Mutation Rates", "Ancestral Sequence", "Sequence Evolution"], "DESC": { "Contact Network": "The Contact Network graph model describes all social interactions:", diff --git a/plugins/common.py b/plugins/common.py index 6ac53f4..199e16f 100644 --- a/plugins/common.py +++ b/plugins/common.py @@ -1,9 +1,18 @@ #! /usr/bin/env python3 +# standard imports from datetime import datetime from sys import stderr import math + +# constants ZERO_THRESH = 0.00000000001 +# non-standard imports +try: + from scipy.stats import truncnorm +except: + error("Unable to import scipy. Install with: pip install scipy") + # dummy plugin function def DUMMY_PLUGIN_FUNC(params, out_fn, config, GLOBAL, verbose=True): pass @@ -28,3 +37,11 @@ def check_props(props): return False tot += p return abs(tot - 1) <= ZERO_THRESH + +# sample from a truncated normal distribution with (non-truncated) mean `loc` and (non-truncated) stdev `scale` in range [`a`,`b`] +# I'm using the Wikipedia notation: https://en.wikipedia.org/wiki/Truncated_normal_distribution +# SciPy's `truncnorm` defines `a` and `b` as "standard deviations above/below `loc`", so I need to convert +def truncnorm_rvs(loc, scale, a_min, b_max, size): + a = (a_min - loc) / scale + b = (b_max - loc) / scale + return truncnorm.rvs(a=a, b=b, loc=loc, scale=scale, size=size) diff --git a/plugins/mutation_rates/common_treeswift.py b/plugins/mutation_rates/common_treeswift.py index 492003b..e22a179 100644 --- a/plugins/mutation_rates/common_treeswift.py +++ b/plugins/mutation_rates/common_treeswift.py @@ -6,10 +6,6 @@ from numpy.random import f as f_dist except: error("Unable to import numpy. Install with: pip install numpy") -try: - from scipy.stats import truncnorm -except: - error("Unable to import scipy. Install with: pip install scipy") try: from treeswift import read_tree_newick except: @@ -216,9 +212,11 @@ def treeswift_triangular(params, out_fn, config, GLOBAL, verbose=True): # Truncated Normal def treeswift_truncnorm(params, out_fn, config, GLOBAL, verbose=True): - tree = read_tree_newick(out_fn['viral_phylogeny_time']); mu = params['mu']; sigma = params['sigma']; a = params['a']; b = params['b'] + mu = params['mu']; sigma = params['sigma']; a_min = params['a']; b_max = params['b'] + tree = read_tree_newick(out_fn['viral_phylogeny_time']) nodes = [node for node in tree.traverse_preorder() if node.edge_length is not None] - rates = truncnorm.rvs(a=a, b=b, loc=mu, scale=sigma, size=len(nodes)) + rates = truncnorm_rvs(loc=mu, scale=sigma, a_min=a_min, b_max=b_max, size=len(nodes)) + print('\n'.join(str(r) for r in rates)); exit() # TODO for i in range(len(nodes)): nodes[i].edge_length *= rates[i] tree.write_tree_newick(out_fn['viral_phylogeny_mut']) diff --git a/plugins/sample_times/time_windows.py b/plugins/sample_times/time_windows.py index 6752558..6376866 100644 --- a/plugins/sample_times/time_windows.py +++ b/plugins/sample_times/time_windows.py @@ -6,7 +6,7 @@ except: error("Unable to import numpy. Install with: pip install numpy") try: - from scipy.stats import truncexpon, truncnorm + from scipy.stats import truncexpon except: error("Unable to import scipy. Install with: pip install scipy") @@ -28,8 +28,7 @@ def time_windows(model, params, out_fn, config, GLOBAL, verbose=True): if model == "Truncated Exponential": variates = list(truncexpon.rvs(1, size=tot_num_samples)) elif model == "Truncated Normal": - corrected_min = (0-params['mu'])/params['sigma']; corrected_max = (1-params['mu'])/params['sigma'] - variates = list(truncnorm.rvs(corrected_min, corrected_max, loc=params['mu'], scale=params['sigma'], size=tot_num_samples)) + variates = list(truncnorm_rvs(loc=params['mu'], scale=params['sigma'], a_min=0, b_max=1, size=tot_num_samples)) for node in windows: for _ in range(params['num_samples']): state, start, end = choice(windows[node]); length = end - start; delta = None