Skip to content

Commit

Permalink
z_sample method (#363)
Browse files Browse the repository at this point in the history
* z_sample

* fixes

* fixes
  • Loading branch information
ordabayevy authored Oct 11, 2022
1 parent a0d5c74 commit 945e65d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
24 changes: 18 additions & 6 deletions tapqir/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,13 +922,26 @@ def ttfb(
help="Plot a binary or probabilistic rastergram",
prompt="Plot a binary rastergram?",
),
num_samples: int = typer.Option(
2000,
"--num-samples",
"-n",
help="Number of posterior samples",
prompt="Number of posterior samples",
),
num_iter: int = typer.Option(
15000,
"--num-iter",
"-it",
help="Number of iterations",
prompt="Number of iterations",
),
):
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
import pyro
import torch
from pyro import distributions as dist
from pyro.ops.stats import hpdi

from tapqir.models import models
Expand All @@ -947,12 +960,14 @@ def ttfb(
model = models[model](device="cpu", dtype="float")
try:
model.load(cd, data_only=False)
model.load_checkpoint(param_only=True)
except TapqirFileNotFoundError as err:
logger.exception(f"Failed to load {err.name} file")
return 1

z = model.params["p_specific"] > 0.5 if binary else model.params["p_specific"]
r_type = "binary" if binary else "probabilistic"
z_samples = model.z_sample(num_samples=num_samples)
for c in range(model.data.C):
# sorted on-target
ttfb = time_to_first_binding(z[: model.data.N, :, c])
Expand All @@ -979,10 +994,7 @@ def ttfb(
# prepare data
Tmax = model.data.F
torch.manual_seed(0)
z_samples = dist.Bernoulli(
model.params["z_probs"][: model.data.N, :, c]
).sample((2000,))
data = time_to_first_binding(z_samples)
data = time_to_first_binding(z_samples[..., c])

# use cuda
torch.set_default_tensor_type(torch.cuda.FloatTensor)
Expand All @@ -992,7 +1004,7 @@ def ttfb(
ttfb_model,
ttfb_guide,
lr=5e-3,
n_steps=15000,
n_steps=num_iter,
data=data.cuda(),
control=None,
Tmax=Tmax,
Expand Down
5 changes: 5 additions & 0 deletions tapqir/models/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,3 +704,8 @@ def pspecific(self) -> torch.Tensor:
@property
def z_map(self) -> torch.Tensor:
return torch.argmax(self.z_probs, dim=-1)

def z_sample(self, num_samples):
return dist.Categorical(self.params["z_probs"][: self.data.N]).sample(
(num_samples,)
)
15 changes: 14 additions & 1 deletion tapqir/models/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import funsor
import torch
import torch.distributions.constraints as constraints
from pyro.distributions.hmm import _logmatmulexp
from pyro.distributions.hmm import _logmatmulexp, _sequential_index
from pyro.ops.indexing import Vindex
from pyroapi import distributions as dist
from pyroapi import handlers, infer, pyro
Expand Down Expand Up @@ -653,3 +653,16 @@ def m_probs(self) -> torch.Tensor:
return Vindex(torch.permute(pyro.param("m_probs").data, (1, 2, 3, 4, 0)))[
..., self.z_map.long()
]

def z_sample(self, num_samples):
init_probs = pyro.param("z_trans").data[: self.data.N, 0, :, 0]
init_probs = init_probs.expand((num_samples,) + init_probs.shape)
x = dist.Categorical(init_probs).sample()
trans_probs = (
pyro.param("z_trans").data[: self.data.N, 1:].permute(0, 2, 1, 3, 4)
)
trans_probs = trans_probs.expand((num_samples,) + trans_probs.shape)
xs = dist.Categorical(trans_probs).sample()
xs = _sequential_index(xs)
x = Vindex(xs)[..., :, x]
return x.permute(0, 1, 3, 2)

0 comments on commit 945e65d

Please sign in to comment.