Poisson context likelihood and better CLI flexibility (#126)
* Add Poisson likelihood functions

* working Poisson context likelihood

* revamp filter method

* format and lint

* most tests passing

* add context likelihood test

* testing tweaks and format

* update docs, format, and lint
willdumm authored Apr 8, 2024
1 parent ff0890c commit e2910b7
Showing 9 changed files with 375 additions and 186 deletions.
4 changes: 2 additions & 2 deletions docs/quickstart.rst
@@ -106,7 +106,7 @@ This file may be manipulated using ``gctree infer``, instead of providing
a dnapars ``outfile``.

.. note::
Although described below, using mutability parsimony or isotype parsimony
Although described below, using context likelihood, mutability parsimony, or isotype parsimony
as ranking criteria is experimental, and has not yet been shown in a careful
validation to improve tree inference. Only the default branching process
likelihood is recommended for tree ranking!
@@ -117,7 +117,7 @@ between trees. Providing arguments ``--isotype_mapfile`` and
arguments ``--mutability`` and ``--substitution`` allows trees to be ranked
according to a context-sensitive mutation model. By default, trees are ranked
lexicographically, first maximizing likelihood, then minimizing isotype
parsimony and mutabilities, if such information is provided.
parsimony, and finally maximizing a context-based Poisson likelihood, if such information is provided.
Ranking priorities can be adjusted using the argument ``--ranking_coeffs``.
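To make the default concrete, here is a minimal sketch of lexicographic ranking over hypothetical per-tree traits (the real ranking is implemented in ``gctree/branching_processes.py``, whose diff is not rendered below):

    # Hypothetical (branching process log likelihood, isotype parsimony,
    # context-based Poisson log likelihood) triples for two candidate trees.
    traits = {"tree1": (-80.2, 3, -25.0), "tree2": (-80.2, 2, -26.1)}

    # Lexicographic: maximize likelihood, then minimize isotype parsimony,
    # then maximize the context likelihood.
    best = max(traits, key=lambda t: (traits[t][0], -traits[t][1], traits[t][2]))
    # likelihoods tie, so the lower isotype parsimony wins -> "tree2"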

For example, to find the optimal tree
316 changes: 187 additions & 129 deletions gctree/branching_processes.py

Large diffs are not rendered by default.

26 changes: 24 additions & 2 deletions gctree/cli.py
@@ -209,6 +209,8 @@ def isotype_add(forest):
mutability_file=args.mutability,
substitution_file=args.substitution,
chain_split=args.chain_split,
branching_process_ranking_coeff=args.branching_process_ranking_coeff,
use_old_mut_parsimony=args.use_old_mut_parsimony,
)

if args.verbose:
@@ -535,7 +537,7 @@ def get_parser():
help=(
"when using concatenated heavy and light chains, this is the 0-based"
" index at which the 2nd chain begins, needed for determining coding frame in both chains,"
" and also to correctly calculate mutability parsimony."
" and also to correctly calculate context-based Poisson likelihood."
),
)
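As a small illustration of what ``--chain_split`` means for concatenated chains (the sequences are hypothetical; the value 6 mirrors the chain-split unit test at the end of this diff):

    heavy = "AAGAAA"            # 0-based positions 0-5
    light = "AATCAA"            # occupies positions 6-11 after concatenation
    sequence = heavy + light
    chain_split = len(heavy)    # 0-based index at which the 2nd chain begins: 6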
parser_infer.add_argument(
@@ -610,6 +612,16 @@ def get_parser():
"See a file excerpt in the documentation for :meth:`mutation_model.MutationModel`."
),
)
parser_infer.add_argument(
"--branching_process_ranking_coeff",
type=float,
default=-1,
help=(
"Coefficient used for branching process likelihood, when ranking trees by a linear "
"combination of traits. This value will be ignored if `--ranking_coeffs` argument is not "
"also provided."
),
)
parser_infer.add_argument(
"--ranking_coeffs",
type=float,
@@ -620,7 +632,17 @@
"Coefficients are in order: isotype parsimony, mutation model parsimony, number of alleles. "
"A coefficient of -1 will be applied to branching process likelihood. "
"If not provided, trees will be ranked lexicographically by likelihood, "
"isotype parsimony, and mutability parsimony in that order."
"isotype parsimony, and context-based Poisson likelihood in that order."
),
)
parser_infer.add_argument(
"--use_old_mut_parsimony",
action="store_true",
help=(
"Use old mutability parsimony instead of poisson context likelihood. Not recommended "
"unless attempting to reproduce results from older versions of gctree. "
"This argument will have no effect unless an S5F model is provided with the arguments "
"`--mutability` and `--substitution`."
),
)
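A rough sketch of how these ranking coefficients might combine, under the assumption that the weighted sum is minimized (hypothetical function; the actual wiring lives in ``gctree/branching_processes.py``):

    # e.g. --branching_process_ranking_coeff -1 --ranking_coeffs 0.01 -1 0
    bp_coeff = -1.0
    iso_coeff, mut_coeff, allele_coeff = 0.01, -1.0, 0.0

    def combined_score(bp_loglik, iso_pars, context_loglik, n_alleles):
        # Lower is better: negative coefficients on the two log likelihoods
        # reward likely trees, positive coefficients penalize parsimony scores.
        return (bp_coeff * bp_loglik + iso_coeff * iso_pars
                + mut_coeff * context_loglik + allele_coeff * n_alleles)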
parser_infer.add_argument(
2 changes: 1 addition & 1 deletion gctree/isotype.py
@@ -46,7 +46,7 @@ def get_parser() -> argparse.ArgumentParser:
" nodes.\n\n"
"This tool doesn’t make any judgements about which tree is best.\n"
"Tree output order is the same as in gctree inference: ranking is\n"
"by log likelihood before isotype additions. A determination of\n"
"by branching process likelihood before isotype additions. A determination of\n"
"which is the best tree is left to the user, based on likelihoods,\n"
"isotype parsimony score, and changes in the number of nodes after\n"
"isotype additions.\n"
19 changes: 11 additions & 8 deletions gctree/isotyping.py
@@ -405,7 +405,7 @@ def explode_idmap(
return newidmap


def _isotype_dagfuncs() -> hdag.utils.AddFuncDict:
def _isotype_dagfuncs() -> hdag.utils.HistoryDagFilter:
"""Return functions for filtering by isotype parsimony score on the history
DAG.
@@ -435,13 +435,16 @@ def edge_weight_func(n1: hdag.HistoryDagNode, n2: hdag.HistoryDagNode):
n1iso = list(n1isos.keys())[0]
return int(sum(isotype_distance(n1iso, n2iso) for n2iso in n2isos.keys()))

return hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": edge_weight_func,
"accum_func": sum,
},
name="Isotype Pars.",
return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": edge_weight_func,
"accum_func": sum,
},
name="Isotype Pars.",
),
min,
)


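Returning a ``HistoryDagFilter`` rather than a bare ``AddFuncDict`` bundles the optimality direction (``min``) with the weight functions, so callers no longer pass ``optimal_func`` separately. A usage sketch mirroring the updated ``test_trim_byisotype`` later in this diff (``tdag`` is assumed to be a history DAG):

    dag_filter = _isotype_dagfuncs()           # weight functions plus min
    counts = tdag.weight_count(**dag_filter)   # isotype parsimony score counts
    tdag.trim_optimal_weight(**dag_filter)     # no explicit optimal_func=min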
142 changes: 102 additions & 40 deletions gctree/mutation_model.py
@@ -10,6 +10,8 @@
import historydag as hdag
from multiset import FrozenMultiset
from typing import Tuple, List, Callable, Optional
import itertools
import math


class MutationModel:
@@ -129,20 +131,25 @@ def mutability(self, kmer: str) -> Tuple[np.float64, np.float64]:
"sequence {} must contain only characters A, C, G, T, or N".format(kmer)
)

mutabilities_to_average, substitutions_to_average = zip(
*[self.context_model[x] for x in MutationModel._disambiguate(kmer)]
)

average_mutability = np.mean(mutabilities_to_average)
average_substitution = {
b: sum(
substitution_dict[b] for substitution_dict in substitutions_to_average
cached = self.context_model.get(kmer, None)
if cached is None:
mutabilities_to_average, substitutions_to_average = zip(
*[self.context_model[x] for x in MutationModel._disambiguate(kmer)]
)
/ len(substitutions_to_average)
for b in "ACGT"
}

return average_mutability, average_substitution
average_mutability = np.mean(mutabilities_to_average)
average_substitution = {
b: sum(
substitution_dict[b]
for substitution_dict in substitutions_to_average
)
/ len(substitutions_to_average)
for b in "ACGT"
}
cached = average_mutability, average_substitution
self.context_model[kmer] = cached

return cached
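A brief usage sketch of the memoized lookup (constructor arguments as in ``tests/test_likelihoods.py`` below; an ambiguous kmer is averaged over its disambiguations once, then served from the cache):

    model = MutationModel(
        mutability_file="HS5F_Mutability.csv",
        substitution_file="HS5F_Substitution.csv",
    )
    mut, subs = model.mutability("NNAAG")   # computed and written to the cache
    mut2, _ = model.mutability("NNAAG")     # second call is a dictionary hit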

def mutabilities(self, sequence: str) -> List[Tuple[np.float64, np.float64]]:
r"""Returns the mutability of a sequence at each site, along with
@@ -440,7 +447,7 @@ def _sequence_disambiguations(sequence, _accum=""):

def _mutability_dagfuncs(
*args, splits: List[int] = [], **kwargs
) -> hdag.utils.AddFuncDict:
) -> hdag.utils.HistoryDagFilter:
"""Return functions for counting mutability parsimony on the history DAG.
Mutability parsimony of a tree is the sum over all edges in the tree
@@ -478,36 +485,38 @@ def distance(node1, node2):
else:
return dist(node1.label.sequence, node2.label.sequence)

return hdag.utils.AddFuncDict(
{"start_func": lambda n: 0, "edge_weight_func": distance, "accum_func": sum},
name="Mut. Pars.",
return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": distance,
"accum_func": sum,
},
name="Mut. Pars.",
),
min,
)


def _mutability_distance_precursors(
mutation_model: MutationModel, splits: List[int] = []
):
chunk_idxs = list(zip([0] + splits, splits + [None]))
# Caching could be moved to the MutationModel class instead.
context_model = mutation_model.context_model.copy()
k = mutation_model.k
h = k // 2
# Build all sequences with (when k=5) one or two Ns on either end
templates = [
("N" * left, "N" * (k - left - right), "N" * right)
for left in range(h + 1)
for right in range(h + 1)
if left != 0 or right != 0
]

kmers_to_compute = [
leftns + stub + rightns
for leftns, ambig_stub, rightns in templates
for stub in _sequence_disambiguations(ambig_stub)
]
# Cache all these mutabilities in context_model also
context_model.update(
{kmer: mutation_model.mutability(kmer) for kmer in kmers_to_compute}

h = mutation_model.k // 2

# Pads sequence with N's, including in the chain-split boundary to
# avoid unrelated sites from being treated as part of each others' context.

# Indices at which padding N's will be in sequences returned from add_ns.
# Does not include indices of last two N's.
padding_indices = set(
itertools.chain.from_iterable(
[
range(split + idx * h, split + (idx + 1) * h)
for idx, split in enumerate([0] + splits)
]
)
)

def add_ns(seq: str):
@@ -535,16 +544,26 @@ def sum_minus_logp(pairs: FrozenMultiset):
p_arr = [
mult
* (
np.log(context_model[mer][0])
+ np.log(context_model[mer][1][newbase])
np.log(mutation_model.mutability(mer)[0])
+ np.log(mutation_model.mutability(mer)[1][newbase])
)
for (mer, newbase), mult in pairs
]
return -sum(p_arr)
else:
return 0.0

return (mutpairs, sum_minus_logp)
def mutability_sum(parent_seq):
padded_seq = add_ns(parent_seq)
for idx in padding_indices:
assert padded_seq[idx] == "N"
return sum(
mutation_model.mutability(padded_seq[idx - h : idx + h + 1])[0]
for idx, _ in enumerate(padded_seq[:-h])
if idx not in padding_indices
)

return (mutpairs, sum_minus_logp, mutability_sum)


def _mutability_distance(mutation_model: MutationModel, splits=[]):
@@ -562,11 +581,54 @@ def _mutability_distance(mutation_model: MutationModel, splits=[]):
Note that, in particular, this function is not symmetric on its arguments.
"""
mutpairs, sum_minus_logp = _mutability_distance_precursors(
mutpairs, sum_minus_logp, _ = _mutability_distance_precursors(
mutation_model, splits=splits
)

def distance(seq1, seq2):
return sum_minus_logp(mutpairs(seq1, seq2))

return distance


def _context_poisson_likelihood(mutation_model: MutationModel, splits=[]):
mutpairs, sum_minus_logp, mutability_sum = _mutability_distance_precursors(
mutation_model, splits=splits
)

def distance(seq1, seq2):
subs = mutpairs(seq1, seq2)
sub_count = len(subs)
if sub_count == 0:
return 0
else:
mut_sum = mutability_sum(seq1)
substitution_sum = -sum_minus_logp(subs)
return (
substitution_sum
+ (sub_count * (math.log(sub_count) - math.log(mut_sum)))
- sub_count
)

return distance
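Read off the code above, the returned value appears to be the profile log-likelihood of an independent-sites Poisson mutation model; the following is a sketch of that reading, not a formula quoted from the source. With S the multiset of observed substitutions, m_i and q_i(b) the parent's mutability and substitution probability at site i, n = |S|, and M the parent's summed mutability:

    \log L = \sum_{i \in S} \log\bigl(m_i \, q_i(b_i)\bigr) + n(\log n - \log M) - n

This equals \sum_{i \in S} \log(t \, m_i \, q_i(b_i)) - tM evaluated at its maximizer t = n/M: each site mutates as a Poisson process with rate proportional to its mutability, and the branch-length scale t is profiled out.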


def _context_poisson_likelihood_dagfuncs(*args, splits: List[int] = [], **kwargs):
mutation_model = MutationModel(*args, **kwargs)
distance = _context_poisson_likelihood(mutation_model, splits=splits)

return hdag.utils.HistoryDagFilter(
hdag.utils.AddFuncDict(
{
"start_func": lambda n: 0,
"edge_weight_func": lambda n1, n2: (
0
if n1.is_ua_node()
else distance(n1.label.sequence, n2.label.sequence)
),
"accum_func": sum,
},
name="LogContextLikelihood",
),
max,
)
8 changes: 8 additions & 0 deletions tests/smalltest.sh
@@ -7,7 +7,15 @@ export MPLBACKEND=agg
mkdir -p tests/smalltest_output
wget -O HS5F_Mutability.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Mutability.csv
wget -O HS5F_Substitution.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Substitution.csv

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs 1 1 0 --use_old_mut_parsimony --branching_process_ranking_coeff 0

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs .01 -1 0 --branching_process_ranking_coeff -1 --summarize_forest --tree_stats

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv

gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv
8 changes: 4 additions & 4 deletions tests/test_isotype.py
@@ -51,9 +51,9 @@ def test_trim_byisotype():
for node in tdag.preorder():
if node.attr is not None:
node.attr["isotype"] = node._dp_data
kwargs = _isotype_dagfuncs()
c = tdag.weight_count(**kwargs)
dag_filter = _isotype_dagfuncs()
c = tdag.weight_count(**dag_filter)
key = min(c)
count = c[key]
tdag.trim_optimal_weight(**kwargs, optimal_func=min)
assert tdag.weight_count(**kwargs) == {key: count}
tdag.trim_optimal_weight(**dag_filter)
assert tdag.weight_count(**dag_filter) == {key: count}
36 changes: 36 additions & 0 deletions tests/test_likelihoods.py
@@ -1,6 +1,8 @@
import gctree.branching_processes as bp
import gctree.phylip_parse as pp
import gctree.utils as utils
import gctree.mutation_model as mm
from math import log

import numpy as np
from multiset import FrozenMultiset
@@ -198,3 +200,37 @@ def test_recursion_depth():
bp.CollapsedTree._max_ll_cache = {}
with np.errstate(all="raise"):
bp.CollapsedTree._ll_genotype(2, 500, 0.4, 0.6)


def test_context_likelihood():
# These files will be present if pytest is run through `make test`.
mutation_model = mm.MutationModel(
mutability_file="HS5F_Mutability.csv", substitution_file="HS5F_Substitution.csv"
)
log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[])

parent_seq = "AAGAAA"
child_seq = "AATCAA"

term1 = sum(
log(
mutation_model.mutability(fivemer)[0]
* mutation_model.mutability(fivemer)[1][target_base]
)
for fivemer, target_base in [("AAGAA", "T"), ("AGAAA", "C")]
)
sum_mutabilities = sum(
mutation_model.mutability(fivemer)[0]
for fivemer in ["NNAAG", "NAAGA", "AAGAA", "AGAAA", "GAAAN", "AAANN"]
)
true_val = term1 + 2 * log(2 / sum_mutabilities) - 2
assert true_val == log_likelihood(parent_seq, child_seq)

# Now test chain split:
parent_seq = parent_seq + parent_seq
child_seq = child_seq + child_seq
# At index 6, the second concatenated sequence starts.
log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[6])

true_val = 2 * term1 + 4 * log(4 / (2 * sum_mutabilities)) - 4
assert true_val == log_likelihood(parent_seq, child_seq)
