Skip to content

Commit

Permalink
update desc
Browse files Browse the repository at this point in the history
  • Loading branch information
cmarshak committed Jan 23, 2025
1 parent a40f613 commit 16f567b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/distmetrics/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _estimate_logit_params_via_streamed_patches(
unfold_gen = unfolding_stream(pre_imgs_stack_t, P, stride, batch_size)

for patch_batch, slices in tqdm(
unfold_gen, total=n_batches, desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled)
unfold_gen, total=n_batches, desc='Chips Traversed (dual pol)', mininterval=2, disable=(not tqdm_enabled)
):
chip_mean, chip_logvar = model(patch_batch)
for k, (sy, sx) in enumerate(slices):
Expand Down Expand Up @@ -395,7 +395,7 @@ def _estimate_logit_params_via_folding(
pred_means_p = torch.zeros(*target_chip_shape).to(device)
pred_logvars_p = torch.zeros(*target_chip_shape).to(device)

for i in tqdm(range(n_batches), desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled)):
for i in tqdm(range(n_batches), desc='Chips Traversed (dual pol)', mininterval=2, disable=(not tqdm_enabled)):
# change last dimension from P**2 to P, P; use -1 because won't always have batch_size as 0th dimension
batch_s = slice(batch_size * i, batch_size * (i + 1))
patch_batch = patches[batch_s, ...].view(-1, T, C, P, P)
Expand Down

0 comments on commit 16f567b

Please sign in to comment.