Merge pull request #36 from HaochenWang1243/main
Updated thinning algo according to NHP's Algorithm 2
iLampard authored Aug 11, 2024
2 parents 53b7b7f + c67216f commit efe3cdd
Showing 1 changed file with 40 additions and 47 deletions.
87 changes: 40 additions & 47 deletions easy_tpp/model/torch_model/torch_thinning.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from easy_tpp.utils import logger


class EventSampler(nn.Module):
@@ -34,6 +35,11 @@ def __init__(self, num_sample, num_exp, over_sample_rate, num_samples_boundary,

def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, intensity_fn,
compute_last_step_only):
# logger.critical(f'time_seq: {time_seq}')
# logger.critical(f'time_delta_seq: {time_delta_seq}')
# logger.critical(f'event_seq: {event_seq}')
# logger.critical(f'intensity_fn: {intensity_fn}')
# logger.critical(f'compute_last_step_only: {compute_last_step_only}')
"""Compute the upper bound of intensity at each event timestamp.
Args:
@@ -54,10 +60,10 @@ def compute_intensity_upper_bound(self, time_seq, time_delta_seq, event_seq, int
steps=self.num_samples_boundary,
device=self.device)[None, None, :]

# [batch_size, seq_len, num_sample]
# [batch_size, seq_len, num_samples_boundary]
dtime_for_bound_sampled = time_delta_seq[:, :, None] * time_for_bound_sampled

# [batch_size, seq_len, num_sample, event_num]
# [batch_size, seq_len, num_samples_boundary, event_num]
intensities_for_bound = intensity_fn(time_seq,
time_delta_seq,
event_seq,
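
The remainder of this function is collapsed in the diff. For orientation, here is a minimal sketch of the reduction the visible comments imply: sum the boundary intensities over event types, take the maximum over the num_samples_boundary sampled points, and inflate by over_sample_rate. The helper name and exact reduction are assumptions for illustration, not the repository's code.

import torch

def intensity_upper_bound_sketch(intensities_for_bound, over_sample_rate):
    # intensities_for_bound: [batch_size, seq_len, num_samples_boundary, event_num]
    # sum over event types -> [batch_size, seq_len, num_samples_boundary]
    total = intensities_for_bound.sum(dim=-1)
    # max over the sampled boundary points, inflated to stay above the true intensity
    # -> [batch_size, seq_len]
    return total.max(dim=-1).values * over_sample_rate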
@@ -120,34 +126,46 @@ def sample_uniform_distribution(self, intensity_upper_bound):

return unif_numbers

def sample_accept(self, unif_numbers, sample_rate, total_intensities):
def sample_accept(self, unif_numbers, sample_rate, total_intensities, exp_numbers):
"""Do the sample-accept process.
For each parallel draw, find its min criterion: if that < 1.0, the 1st (i.e. smallest) sampled time
with cri < 1.0 is accepted; if none is accepted, use boundary / maxsampletime for that draw
For the accumulated exp (delta) samples drawn for each event timestamp, find (from left to right) the first
one that makes the criterion < 1 and accept it as the sampled next-event time. If every exp sample is rejected
(criterion >= 1), set the sampled next-event time to dtime_max.
Args:
unif_numbers (tensor): [batch_size, max_len, num_sample, num_exp], sampled uniform random number.
sample_rate (tensor): [batch_size, max_len], sample rate (intensity).
total_intensities (tensor): [batch_size, seq_len, num_sample, num_exp], total intensities at the sampled times.
exp_numbers (tensor): [batch_size, seq_len, num_sample, num_exp], sampled exp numbers (delta in Algorithm 2).
Returns:
list: two tensors,
criterion, [batch_size, max_len, num_sample, num_exp]
who_has_accepted_times, [batch_size, max_len, num_sample]
result (tensor): [batch_size, seq_len, num_sample], sampled next-event times.
"""

# [batch_size, max_len, num_sample, num_exp]
criterion = unif_numbers * sample_rate[:, :, None, None] / total_intensities


# [batch_size, max_len, num_sample, num_exp]
masked_crit_less_than_1 = torch.where(criterion < 1, 1, 0)

# [batch_size, max_len, num_sample]
min_cri_each_draw, _ = criterion.min(dim=-1)

# [batch_size, max_len, num_sample], True where every exp sample is rejected, i.e. criterion >= 1 for all num_exp draws
non_accepted_filter = (1 - masked_crit_less_than_1).all(dim=3)

# [batch_size, max_len, num_sample]
who_has_accepted_times = min_cri_each_draw < 1.0

return criterion, who_has_accepted_times
# [batch_size, max_len, num_sample], index of the first accepted exp sample per draw (argmax returns the first maximal entry)
first_accepted_indexer = masked_crit_less_than_1.argmax(dim=3)

# [batch_size, max_len, num_sample, 1]
# the indexer must be unsqueezed to 4D to match the number of dimensions of exp_numbers
result_non_accepted_unfiltered = torch.gather(exp_numbers, 3, first_accepted_indexer.unsqueeze(3))

# [batch_size, max_len, num_sample, 1]
result = torch.where(non_accepted_filter.unsqueeze(3), torch.tensor(self.dtime_max), result_non_accepted_unfiltered)

# [batch_size, max_len, num_sample]
result = result.squeeze(dim=-1)

return result
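
To make the new control flow concrete, here is a toy run of the same tensor operations (a standalone sketch with made-up shapes and values, not repository code):

import torch

# toy shapes: batch_size = 1, max_len = 1, num_sample = 2, num_exp = 3
criterion = torch.tensor([[[[1.4, 0.6, 0.9],      # draw 0: first acceptance at index 1
                            [1.2, 1.1, 1.3]]]])   # draw 1: every sample rejected
exp_numbers = torch.tensor([[[[0.2, 0.5, 0.9],
                              [0.1, 0.4, 0.8]]]])
dtime_max = 10.0

masked = torch.where(criterion < 1, 1, 0)
non_accepted = (1 - masked).all(dim=3)            # [[[False, True]]]
first_idx = masked.argmax(dim=3)                  # argmax returns the first max, i.e. the first acceptance
picked = torch.gather(exp_numbers, 3, first_idx.unsqueeze(3))
result = torch.where(non_accepted.unsqueeze(3), torch.tensor(dtime_max), picked).squeeze(-1)
print(result)  # tensor([[[ 0.5000, 10.0000]]])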

def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_boundary,
intensity_fn, compute_last_step_only=False):
@@ -177,7 +195,8 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou
# we apply fast approximation, i.e., re-use exp sample times for computation
# [batch_size, seq_len, num_exp]
exp_numbers = self.sample_exp_distribution(intensity_upper_bound)

# accumulate the exp samples into increasing candidate times (the delta accumulation in NHP's Algorithm 2)
exp_numbers = torch.cumsum(exp_numbers, dim=-1)

# 3. compute intensity at sampled times from exp distribution
# [batch_size, seq_len, num_exp, event_num]
intensities_at_sampled_times = intensity_fn(time_seq,
@@ -193,46 +212,20 @@ def draw_next_time_one_step(self, time_seq, time_delta_seq, event_seq, dtime_bou
# add one dim of num_sample: re-use the intensity for samples for prediction
# [batch_size, seq_len, num_sample, num_exp]
total_intensities = torch.tile(total_intensities[:, :, None, :], [1, 1, self.num_sample, 1])

# [batch_size, seq_len, num_sample, num_exp]
exp_numbers = torch.tile(exp_numbers[:, :, None, :], [1, 1, self.num_sample, 1])

# 4. draw uniform distribution
# [batch_size, seq_len, num_sample, num_exp]
unif_numbers = self.sample_uniform_distribution(intensity_upper_bound)

# 5. find out accepted intensities
# criterion, [batch_size, max_len, num_sample, num_exp]
# who_has_accepted_times, [batch_size, max_len, num_sample]
criterion, who_has_accepted_times = self.sample_accept(unif_numbers, intensity_upper_bound,
total_intensities)

# 6. find out accepted dtimes
sampled_dtimes_accepted = exp_numbers.clone()

# for unaccepted, use boundary/maxsampletime for that draw
sampled_dtimes_accepted[criterion >= 1.0] = exp_numbers.max() + 1.0

accepted_times_each_draw, accepted_id_each_draw = sampled_dtimes_accepted.min(dim=-1)

# 7. fill out result
dtime_boundary_ = dtime_boundary[:, -1:] if compute_last_step_only else dtime_boundary

# [batch_size, seq_len, num_sample]
dtime_boundary_ = torch.tile(dtime_boundary_[..., None], [1, 1, self.num_sample])

# [batch_size, seq_len, num_sample]
res = torch.ones_like(dtime_boundary_) * dtime_boundary_
res = self.sample_accept(unif_numbers, intensity_upper_bound, total_intensities, exp_numbers)

# [batch_size, seq_len, num_sample]
weights = torch.ones_like(dtime_boundary_)
weights /= weights.sum(dim=-1, keepdim=True)

res[who_has_accepted_times] = accepted_times_each_draw[who_has_accepted_times]
who_not_accept = ~who_has_accepted_times

who_reach_further = exp_numbers[..., -1] > dtime_boundary_

res[who_not_accept & who_reach_further] = exp_numbers[..., -1][who_not_accept & who_reach_further]

weights = torch.ones_like(res) / res.shape[2]

# add an upper bound here in case it explodes, e.g., in ODE models
return res.clamp(max=1e5), weights
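
For context on what the updated code implements end to end, below is a minimal single-sequence sketch of thinning per the NHP paper's Algorithm 2: accumulate Exponential(lam_bar) gaps into candidate times (the new torch.cumsum step) and accept the first candidate whose uniform draw passes the criterion, falling back to dtime_max when all candidates are rejected (the new sample_accept). The function name, scalar intensity, and constants are illustrative assumptions; the repository version operates on batched tensors as shown in the diff.

import torch

def draw_next_time_sketch(intensity_at, lam_bar, num_exp=100, dtime_max=10.0):
    # intensity_at: callable mapping times t to the total intensity lambda(t),
    # assumed bounded above by lam_bar on the horizon of interest
    gaps = torch.distributions.Exponential(lam_bar).sample((num_exp,))
    times = torch.cumsum(gaps, dim=0)             # increasing candidate times
    u = torch.rand(num_exp)
    accepted = u * lam_bar < intensity_at(times)  # same test as criterion < 1 above
    idx = torch.nonzero(accepted)
    if idx.numel() > 0:
        return times[idx[0, 0]]                   # first (smallest) accepted time
    return torch.tensor(dtime_max)                # all rejected: fall back to dtime_max

# toy usage: a decaying intensity bounded above by lam_bar = 0.8
next_dt = draw_next_time_sketch(lambda t: 0.5 + 0.3 * torch.exp(-t), lam_bar=0.8)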
