Skip to content

Commit

Permalink
Merge pull request #267 from hyanwong/iterative-progress
Browse files Browse the repository at this point in the history
Output number of iterations using EP
  • Loading branch information
hyanwong authored Jun 14, 2023
2 parents 26b5081 + 2fa8dcc commit 667e6f2
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 37 deletions.
88 changes: 71 additions & 17 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,74 @@ def test_default_values_preprocess(self):
assert args.trim_telomeres


class TestEndToEnd:
class RunCLI:
    """
    Mixin for test classes that need to run the ``date`` subcommand of the CLI.
    """

    def run_tsdate_cli(self, input_ts, cmd=""):
        """
        Dump ``input_ts`` into a temporary directory, run ``cli.tsdate_main``
        with the ``date`` subcommand on it, and return the loaded output file.

        :param input_ts: object exposing ``dump(path)``; written to "input.trees".
        :param str cmd: additional whitespace-separated command-line arguments,
            appended after the input and output filenames.
        :return: the result of ``tskit.load`` on the output file, loaded before
            the temporary directory is removed.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            input_filename = pathlib.Path(tmpdir) / "input.trees"
            input_ts.dump(input_filename)
            output_filename = pathlib.Path(tmpdir) / "output.trees"
            # Build argv as a list rather than splitting one joined string, so
            # a temp directory whose path contains whitespace stays a single
            # argument. Only the caller-supplied `cmd` is split.
            argv = ["date", str(input_filename), str(output_filename), *cmd.split()]
            cli.tsdate_main(argv)
            return tskit.load(output_filename)


class TestOutput(RunCLI):
    """
    Checks on what the command-line interface writes to stdout and stderr.
    """

    popsize = 1

    def test_bad_method(self, capfd):
        # An unrecognised --method value must exit and be named on stderr
        unknown_method = "bad_method"
        ts = msprime.simulate(4, random_seed=123)
        extra_args = f"--method {unknown_method}"
        with pytest.raises(SystemExit):
            self.run_tsdate_cli(ts, f"{self.popsize} " + extra_args)
        stderr_text = capfd.readouterr().err
        assert unknown_method in stderr_text

    def test_no_output(self, capfd):
        # Without extra flags, nothing is expected on either stream
        ts = msprime.simulate(4, random_seed=123)
        self.run_tsdate_cli(ts, f"{self.popsize}")
        captured = capfd.readouterr()
        assert captured.out == ""
        assert captured.err == ""

    def test_progress(self, capfd):
        ts = msprime.simulate(4, random_seed=123)
        extra_args = "--method inside_outside --progress"
        self.run_tsdate_cli(ts, f"{self.popsize} " + extra_args)
        captured = capfd.readouterr()
        assert captured.out == ""
        # Progress bars are expected on stderr: one completed bar per stage
        stderr_text = captured.err
        expected_stages = (
            "Find Node Spans",
            "TipCount",
            "Calculating Node Age Variances",
            "Find Mixture Priors",
            "Inside",
            "Outside",
            "Constrain Ages",
        )
        n_stages = len(expected_stages)
        for stage in expected_stages:
            assert stage in stderr_text
        assert stderr_text.count("100%") == n_stages
        assert stderr_text.count("it/s") >= n_stages

    def test_iterative_progress(self, capfd):
        ts = msprime.simulate(4, random_seed=123)
        extra_args = "--method variational_gamma --mutation-rate 1e-8 --progress"
        self.run_tsdate_cli(ts, f"{self.popsize} " + extra_args)
        captured = capfd.readouterr()
        assert captured.out == ""
        # Progress bars are expected on stderr. The first iteration uses the
        # generic label for both passes; later ones name iteration and direction.
        stderr_text = captured.err
        assert stderr_text.count("Expectation Propagation: 100%") == 2
        assert stderr_text.count("EP (iter 2, rootwards): 100%") == 1
        n_rootwards = stderr_text.count("rootwards): 100%")
        n_leafwards = stderr_text.count("leafwards): 100%")
        assert n_rootwards == n_leafwards


class TestEndToEnd(RunCLI):
"""
Class to test input to CLI outputs dated tree sequences.
"""
Expand Down Expand Up @@ -196,29 +263,16 @@ def ts_equal(self, ts1, ts2, times_equal=False):
assert t1.nodes == t2.nodes

def verify(self, input_ts, cmd):
    """
    Run the CLI on ``input_ts`` with arguments ``cmd`` and check the output.

    Asserts that the output has the same ``num_samples`` as the input and
    then passes both to ``self.ts_equal``.
    """
    # Run the CLI exactly once, through the shared run_tsdate_cli helper,
    # rather than repeating the temp-file handling inline.
    output_ts = self.run_tsdate_cli(input_ts, cmd)
    assert input_ts.num_samples == output_ts.num_samples
    self.ts_equal(input_ts, output_ts)

def compare_python_api(self, input_ts, cmd, Ne, mutation_rate, method):
    """
    Check that the CLI and ``tsdate.date`` assign identical node times.

    :param input_ts: tree sequence passed to both the CLI and the Python API.
    :param str cmd: command-line arguments passed to ``run_tsdate_cli``.
    :param Ne: passed to ``tsdate.date`` as ``population_size``.
    :param mutation_rate: passed to ``tsdate.date`` as ``mutation_rate``.
    :param method: passed to ``tsdate.date`` as ``method``.
    """
    # Run the CLI exactly once, through the shared run_tsdate_cli helper,
    # rather than repeating the temp-file handling inline.
    output_ts = self.run_tsdate_cli(input_ts, cmd)
    dated_ts = tsdate.date(
        input_ts, population_size=Ne, mutation_rate=mutation_rate, method=method
    )
    assert np.array_equal(dated_ts.nodes_time, output_ts.nodes_time)

def test_ts(self):
input_ts = msprime.simulate(10, random_seed=1)
Expand Down
15 changes: 8 additions & 7 deletions tsdate/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,15 @@ def tsdate_cli_parser():
)
parser.add_argument(
    "--method",
    type=str,
    choices=["inside_outside", "maximization", "variational_gamma"],
    default="inside_outside",
    # Adjacent string literals are concatenated into ONE str. Do not put a
    # trailing comma inside the parentheses: that would turn `help` into a
    # 1-tuple, which argparse cannot %-format when rendering the help text.
    help=(
        "Specify which estimation method to use: "
        "'inside_outside' is empirically better, but theoretically problematic, "
        "'maximization' is worse empirically, especially with a gamma prior, but "
        "theoretically robust, 'variational_gamma' is a fast experimental "
        "continuous-time approximation. Current default: 'inside_outside'"
    ),
)
parser.add_argument(
"--ignore-oldest",
Expand Down
33 changes: 20 additions & 13 deletions tsdate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,19 +992,14 @@ def __init__(self, *args, **kwargs):
)
# self.factor_norm[edge.id] += ... # TODO

def propagate(self, *, edges, progress=None):
def propagate(self, *, edges, desc=None, progress=None):
"""
Update approximating factor for each edge
"""
if progress is None:
progress = self.progress
# TODO: this will still converge if parallelized (potentially slower)
for edge in tqdm(
edges,
desc="Expectation Propagation",
total=self.ts.num_edges,
disable=not progress,
):
for edge in tqdm(edges, desc, total=self.ts.num_edges, disable=not progress):
if edge.child in self.fixednodes:
continue
if edge.parent in self.fixednodes:
Expand Down Expand Up @@ -1042,13 +1037,22 @@ def propagate(self, *, edges, progress=None):
# TODO not complete
self.factor_norm[edge.id] = norm_const

def iterate(self, *, progress=None, **kwargs):
def iterate(self, *, iter_num=None, progress=None):
"""
Update edge factors from leaves to root then from root to leaves,
and return approximate log marginal likelihood
"""
self.propagate(edges=self.edges_by_parent_asc(grouped=False), progress=progress)
self.propagate(edges=self.edges_by_child_desc(grouped=False), progress=progress)
desc = "Expectation Propagation"
if iter_num: # Show iteration number if not first iteration
desc = f"EP (iter {iter_num + 1:>2}, rootwards)"
self.propagate(
edges=self.edges_by_parent_asc(grouped=False), desc=desc, progress=progress
)
if iter_num:
desc = f"EP (iter {iter_num + 1:>2}, leafwards)"
self.propagate(
edges=self.edges_by_child_desc(grouped=False), desc=desc, progress=progress
)
# TODO
# marginal_lik = np.sum(self.factor_norm)
# return marginal_lik
Expand Down Expand Up @@ -1112,7 +1116,10 @@ def constrain_ages_topo(ts, post_mn, eps, nodes_to_date=None, progress=False):
parents_unique = np.unique(parents, return_index=True)
parent_indices = parents_unique[1][np.isin(parents_unique[0], nodes_to_date)]
for index, nd in tqdm(
enumerate(sorted(nodes_to_date)), desc="Constrain Ages", disable=not progress
enumerate(sorted(nodes_to_date)),
desc="Constrain Ages",
total=len(nodes_to_date),
disable=not progress,
):
if index + 1 != len(nodes_to_date):
children_index = np.arange(parent_indices[index], parent_indices[index + 1])
Expand Down Expand Up @@ -1530,8 +1537,8 @@ def variational_dates(
)

dynamic_prog = ExpectationPropagation(priors, liklhd, progress=progress)
for _ in range(max_iterations):
dynamic_prog.iterate()
for it in range(max_iterations):
dynamic_prog.iterate(iter_num=it)
posterior = dynamic_prog.posterior
tree_sequence, mn_post, _ = variational_mean_var(
tree_sequence, posterior, fixed_node_set=fixed_nodes
Expand Down

0 comments on commit 667e6f2

Please sign in to comment.