From 16f567bc39d7e4b96aa4927ed910f4f78ab82ca1 Mon Sep 17 00:00:00 2001 From: Charlie Marshak Date: Wed, 22 Jan 2025 17:53:07 -0800 Subject: [PATCH] update desc --- src/distmetrics/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distmetrics/transformer.py b/src/distmetrics/transformer.py index 3c5d907..6cd9236 100644 --- a/src/distmetrics/transformer.py +++ b/src/distmetrics/transformer.py @@ -293,7 +293,7 @@ def _estimate_logit_params_via_streamed_patches( unfold_gen = unfolding_stream(pre_imgs_stack_t, P, stride, batch_size) for patch_batch, slices in tqdm( - unfold_gen, total=n_batches, desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled) + unfold_gen, total=n_batches, desc='Chips Traversed (dual pol)', mininterval=2, disable=(not tqdm_enabled) ): chip_mean, chip_logvar = model(patch_batch) for k, (sy, sx) in enumerate(slices): @@ -395,7 +395,7 @@ def _estimate_logit_params_via_folding( pred_means_p = torch.zeros(*target_chip_shape).to(device) pred_logvars_p = torch.zeros(*target_chip_shape).to(device) - for i in tqdm(range(n_batches), desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled)): + for i in tqdm(range(n_batches), desc='Chips Traversed (dual pol)', mininterval=2, disable=(not tqdm_enabled)): # change last dimension from P**2 to P, P; use -1 because won't always have batch_size as 0th dimension batch_s = slice(batch_size * i, batch_size * (i + 1)) patch_batch = patches[batch_s, ...].view(-1, T, C, P, P)