diff --git a/global.json b/global.json
index f028edc..b7a1681 100644
--- a/global.json
+++ b/global.json
@@ -1,5 +1,5 @@
{
- "VERSION": "1.0.1",
+ "VERSION": "1.0.2",
"CONFIG_KEYS": ["Contact Network", "Transmission Network", "Sample Times", "Viral Phylogeny (Transmissions)", "Viral Phylogeny (Seeds)", "Mutation Rates", "Ancestral Sequence", "Sequence Evolution"],
"DESC": {
"Contact Network": "The Contact Network graph model describes all social interactions:
- Nodes represent individuals in the population
- Edges represent all interactions across which the pathogen can transmit
- Currently, FAVITES-Lite only supports static (i.e., unchanging) contact networks
",
diff --git a/plugins/common.py b/plugins/common.py
index 6ac53f4..199e16f 100644
--- a/plugins/common.py
+++ b/plugins/common.py
@@ -1,9 +1,18 @@
#! /usr/bin/env python3
+# standard imports
from datetime import datetime
from sys import stderr
import math
+
+# constants
ZERO_THRESH = 0.00000000001
+# non-standard imports
+try:
+ from scipy.stats import truncnorm
+except:
+ error("Unable to import scipy. Install with: pip install scipy")
+
# dummy plugin function
def DUMMY_PLUGIN_FUNC(params, out_fn, config, GLOBAL, verbose=True):
pass
@@ -28,3 +37,11 @@ def check_props(props):
return False
tot += p
return abs(tot - 1) <= ZERO_THRESH
+
+# sample from a truncated normal distribution with (non-truncated) mean `loc` and (non-truncated) stdev `scale` in range [`a`,`b`]
+# I'm using the Wikipedia notation: https://en.wikipedia.org/wiki/Truncated_normal_distribution
+# SciPy's `truncnorm` defines `a` and `b` as "standard deviations above/below `loc`", so I need to convert
+def truncnorm_rvs(loc, scale, a_min, b_max, size):
+ a = (a_min - loc) / scale
+ b = (b_max - loc) / scale
+ return truncnorm.rvs(a=a, b=b, loc=loc, scale=scale, size=size)
diff --git a/plugins/mutation_rates/common_treeswift.py b/plugins/mutation_rates/common_treeswift.py
index 492003b..e22a179 100644
--- a/plugins/mutation_rates/common_treeswift.py
+++ b/plugins/mutation_rates/common_treeswift.py
@@ -6,10 +6,6 @@
from numpy.random import f as f_dist
except:
error("Unable to import numpy. Install with: pip install numpy")
-try:
- from scipy.stats import truncnorm
-except:
- error("Unable to import scipy. Install with: pip install scipy")
try:
from treeswift import read_tree_newick
except:
@@ -216,9 +212,11 @@ def treeswift_triangular(params, out_fn, config, GLOBAL, verbose=True):
# Truncated Normal
def treeswift_truncnorm(params, out_fn, config, GLOBAL, verbose=True):
- tree = read_tree_newick(out_fn['viral_phylogeny_time']); mu = params['mu']; sigma = params['sigma']; a = params['a']; b = params['b']
+ mu = params['mu']; sigma = params['sigma']; a_min = params['a']; b_max = params['b']
+ tree = read_tree_newick(out_fn['viral_phylogeny_time'])
nodes = [node for node in tree.traverse_preorder() if node.edge_length is not None]
- rates = truncnorm.rvs(a=a, b=b, loc=mu, scale=sigma, size=len(nodes))
+ rates = truncnorm_rvs(loc=mu, scale=sigma, a_min=a_min, b_max=b_max, size=len(nodes))
+ print('\n'.join(str(r) for r in rates)); exit() # TODO
for i in range(len(nodes)):
nodes[i].edge_length *= rates[i]
tree.write_tree_newick(out_fn['viral_phylogeny_mut'])
diff --git a/plugins/sample_times/time_windows.py b/plugins/sample_times/time_windows.py
index 6752558..6376866 100644
--- a/plugins/sample_times/time_windows.py
+++ b/plugins/sample_times/time_windows.py
@@ -6,7 +6,7 @@
except:
error("Unable to import numpy. Install with: pip install numpy")
try:
- from scipy.stats import truncexpon, truncnorm
+ from scipy.stats import truncexpon
except:
error("Unable to import scipy. Install with: pip install scipy")
@@ -28,8 +28,7 @@ def time_windows(model, params, out_fn, config, GLOBAL, verbose=True):
if model == "Truncated Exponential":
variates = list(truncexpon.rvs(1, size=tot_num_samples))
elif model == "Truncated Normal":
- corrected_min = (0-params['mu'])/params['sigma']; corrected_max = (1-params['mu'])/params['sigma']
- variates = list(truncnorm.rvs(corrected_min, corrected_max, loc=params['mu'], scale=params['sigma'], size=tot_num_samples))
+ variates = list(truncnorm_rvs(loc=params['mu'], scale=params['sigma'], a_min=0, b_max=1, size=tot_num_samples))
for node in windows:
for _ in range(params['num_samples']):
state, start, end = choice(windows[node]); length = end - start; delta = None