v0.0.4 (#16)
cmarshak authored Jan 23, 2025
2 parents 2bb5914 + 95def2e commit 68b2bdd
Showing 5 changed files with 160 additions and 105 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [PEP 440](https://www.python.org/dev/peps/pep-0440/)
and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.0.4] - 2025-01-18

### Fixed
* Fixed transformer to ensure last pre-image always has correct index

### Added
* arXiv link to README
* Installation instructions
* `tqdm` description for despeckling
* Allow user to disable `tqdm` for despeckle

### Changed
* Renamed `load_trained_transformer_model` to `load_transformer_model`
* Renamed inputs to reflect (possible) usage with other polarizations: `vv` -> `copol` and `vh` -> `crosspol`. The APIs do not change, only the input argument names.


## [0.0.3] - 2025-01-16

### Fixed
10 changes: 8 additions & 2 deletions README.md
@@ -33,7 +33,13 @@ It is worth noting that other metrics can be generated from the above using `+`,

## Installation

...
We recommend using the `conda/mamba` package manager to install this library.

```
mamba install -c conda-forge distmetrics
```

You can also use `pip`, although this does not guarantee that all dependencies are installed correctly.
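
For example, assuming the package is published on PyPI under the same name:

```
pip install distmetrics
```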


### For development
@@ -55,7 +61,7 @@ get_device() # should be `cuda` if GPU is available or `mps` if using mac M chip

# References

<a id=1>[1]</a> H. Hardiman Mostow et al., "Deep Self-Supervised Disturbance Mapping with Sentinel-1 OPERA RTC Synthetic Aperture Radar", *in preparation 2024*.
<a id=1>[1]</a> H. Hardiman Mostow et al., "Deep Self-Supervised Disturbance Mapping with Sentinel-1 OPERA RTC Synthetic Aperture Radar", [arXiv](https://arxiv.org/abs/2409.15568).

<a id=2>[2]</a> O. L. Stephenson et al., "Deep Learning-Based Damage Mapping With InSAR Coherence Time Series," in IEEE Transactions on Geoscience and Remote Sensing, vol. 60, pp. 1-17, 2022, Art no. 5207917, doi: 10.1109/TGRS.2021.3084209. https://arxiv.org/abs/2105.11544

168 changes: 87 additions & 81 deletions notebooks/transformer-on-landslide.ipynb

Large diffs are not rendered by default.

15 changes: 13 additions & 2 deletions src/distmetrics/despeckle.py
@@ -16,11 +16,22 @@ def despeckle_one_rtc_arr_with_tv(X: np.ndarray, reg_param: float = 5, noise_flo


def despeckle_rtc_arrs_with_tv(
arrs: list[np.ndarray], reg_param: float = 5, noise_floor_db: float = -22, n_jobs: int = 10
arrs: list[np.ndarray],
reg_param: float = 5,
noise_floor_db: float = -22,
n_jobs: int = 10,
tqdm_enabled: bool = True,
) -> list[np.ndarray]:
def dspkl(X: np.ndarray) -> np.ndarray:
return despeckle_one_rtc_arr_with_tv(X, reg_param=reg_param, noise_floor_db=noise_floor_db)

with WorkerPool(n_jobs=n_jobs, use_dill=True) as pool:
arrs_dspk = pool.map(dspkl, arrs, progress_bar=True, progress_bar_style='std', concatenate_numpy_output=False)
arrs_dspk = pool.map(
dspkl,
arrs,
progress_bar=tqdm_enabled,
progress_bar_style='std',
concatenate_numpy_output=False,
progress_bar_options=dict(desc='Despeckling', dynamic_ncols=True),
)
return arrs_dspk
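
For reference, a minimal sketch of the new `tqdm_enabled` flag, assuming the function is importable from `distmetrics.despeckle`; the arrays, sizes, and `n_jobs` value below are illustrative assumptions:

```python
import numpy as np

from distmetrics.despeckle import despeckle_rtc_arrs_with_tv

# Illustrative inputs: small random "RTC-like" backscatter arrays with values in (0, 1).
arrs = [np.random.rand(256, 256).astype('float32') for _ in range(4)]

# Default behavior: a tqdm progress bar labeled 'Despeckling' is shown.
arrs_dspk = despeckle_rtc_arrs_with_tv(arrs, n_jobs=2)

# New in 0.0.4: suppress the progress bar, e.g. for batch jobs or log-only environments.
arrs_dspk_quiet = despeckle_rtc_arrs_with_tv(arrs, n_jobs=2, tqdm_enabled=False)
```
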
56 changes: 36 additions & 20 deletions src/distmetrics/transformer.py
@@ -10,7 +10,7 @@
import torch.nn.functional as F
from pydantic import BaseModel, model_validator
from scipy.special import logit
from tqdm import tqdm
from tqdm.auto import tqdm

from distmetrics.mahalanobis import _transform_pre_arrs
from distmetrics.model_data.transformer_config import transformer_config, transformer_latest_config
@@ -145,7 +145,12 @@ def forward(self, img_baseline: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso
) # batch, seq_len, num_patches, data_dim

img_baseline = (
self.embedding(img_baseline) + self.spatial_pos_embed + self.temporal_pos_embed[:, :seq_len, :, :]
self.embedding(img_baseline)
+ self.spatial_pos_embed
# changed self.temporal_pos_embed[:, :seq_len, :, :]
# to self.temporal_pos_embed[:, (self.max_seq_len-seq_len):, :, :] to ensure last pre-image always has
# correct index
+ self.temporal_pos_embed[:, (self.max_seq_len - seq_len) :, :, :]
) # batch, seq_len, num_patches, d_model
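# Illustration with hypothetical sizes: if max_seq_len == 10 and a baseline of seq_len == 4 pre-images
# is passed, the old slice [:, :4, :, :] assigned temporal positions 0-3, whereas the new slice
# [:, 10 - 4 :, :, :] assigns positions 6-9, so the most recent pre-image always receives position
# max_seq_len - 1 regardless of how many pre-images are supplied.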

img_baseline = img_baseline.view(
@@ -217,8 +222,8 @@ def load_transformer_model(model_token: str = 'latest') -> SpatioTemporalTransfo
@torch.no_grad()
def _estimate_logit_params_via_streamed_patches(
model: torch.nn.Module,
pre_imgs_vv: list[np.ndarray],
pre_imgs_vh: list[np.ndarray],
imgs_copol: list[np.ndarray],
imgs_crosspol: list[np.ndarray],
stride: int = 2,
batch_size: int = 32,
max_nodata_ratio: float = 0.1,
@@ -256,7 +261,7 @@ def _estimate_logit_params_via_streamed_patches(
device = get_device()

# stack to T x 2 x H x W
pre_imgs_stack = _transform_pre_arrs(pre_imgs_vv, pre_imgs_vh)
pre_imgs_stack = _transform_pre_arrs(imgs_copol, imgs_crosspol)
pre_imgs_stack = pre_imgs_stack.astype('float32')

# Mask
@@ -288,7 +293,12 @@ def _estimate_logit_params_via_streamed_patches(
unfold_gen = unfolding_stream(pre_imgs_stack_t, P, stride, batch_size)

for patch_batch, slices in tqdm(
unfold_gen, total=n_batches, desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled)
unfold_gen,
total=n_batches,
desc='Chips Traversed',
mininterval=2,
disable=(not tqdm_enabled),
dynamic_ncols=True,
):
chip_mean, chip_logvar = model(patch_batch)
for k, (sy, sx) in enumerate(slices):
@@ -313,8 +323,8 @@ def _estimate_logit_params_via_streamed_patches(
@torch.no_grad()
def _estimate_logit_params_via_folding(
model: torch.nn.Module,
pre_imgs_vv: list[np.ndarray],
pre_imgs_vh: list[np.ndarray],
imgs_copol: list[np.ndarray],
imgs_crosspol: list[np.ndarray],
stride: int = 2,
batch_size: int = 32,
tqdm_enabled: bool = True,
@@ -352,7 +362,7 @@ def _estimate_logit_params_via_folding(
device = get_device()

# stack to T x 2 x H x W
pre_imgs_stack = _transform_pre_arrs(pre_imgs_vv, pre_imgs_vh)
pre_imgs_stack = _transform_pre_arrs(imgs_copol, imgs_crosspol)
pre_imgs_stack = pre_imgs_stack.astype('float32')

# Mask
@@ -390,7 +400,13 @@ def _estimate_logit_params_via_folding(
pred_means_p = torch.zeros(*target_chip_shape).to(device)
pred_logvars_p = torch.zeros(*target_chip_shape).to(device)

for i in tqdm(range(n_batches), desc='Chips Traversed', mininterval=2, disable=(not tqdm_enabled)):
for i in tqdm(
range(n_batches),
desc='Chips Traversed',
mininterval=2,
disable=(not tqdm_enabled),
dynamic_ncols=True,
):
# change last dimension from P**2 to P, P; use -1 because won't always have batch_size as 0th dimension
batch_s = slice(batch_size * i, batch_size * (i + 1))
patch_batch = patches[batch_s, ...].view(-1, T, C, P, P)
@@ -430,8 +446,8 @@

def estimate_normal_params_of_logits(
model: torch.nn.Module,
pre_imgs_vv: list[np.ndarray],
pre_imgs_vh: list[np.ndarray],
imgs_copol: list[np.ndarray],
imgs_crosspol: list[np.ndarray],
stride: int = 2,
batch_size: int = 32,
tqdm_enabled: bool = True,
@@ -445,17 +461,17 @@ def estimate_normal_params_of_logits(
)

mu, sigma = estimate_logits(
model, pre_imgs_vv, pre_imgs_vh, stride=stride, batch_size=batch_size, tqdm_enabled=tqdm_enabled
model, imgs_copol, imgs_crosspol, stride=stride, batch_size=batch_size, tqdm_enabled=tqdm_enabled
)
return mu, sigma


def compute_transformer_zscore(
model: torch.nn.Module,
pre_imgs_vv: list[np.ndarray],
pre_imgs_vh: list[np.ndarray],
post_arr_vv: np.ndarray,
post_arr_vh: np.ndarray,
pre_imgs_copol: list[np.ndarray],
pre_imgs_crosspol: list[np.ndarray],
post_arr_copol: np.ndarray,
post_arr_crosspol: np.ndarray,
stride: int = 4,
batch_size: int = 32,
tqdm_enabled: bool = True,
@@ -481,15 +497,15 @@ def compute_transformer_zscore(

mu, sigma = estimate_normal_params_of_logits(
model,
pre_imgs_vv,
pre_imgs_vh,
pre_imgs_copol,
pre_imgs_crosspol,
stride=stride,
batch_size=batch_size,
tqdm_enabled=tqdm_enabled,
memory_strategy=memory_strategy,
)

post_arr_logit_s = logit(np.stack([post_arr_vv, post_arr_vh], axis=0))
post_arr_logit_s = logit(np.stack([post_arr_copol, post_arr_crosspol], axis=0))
z_score_dual = np.abs(post_arr_logit_s - mu) / sigma
z_score = agg(z_score_dual, axis=0)
m_dist = DiagMahalanobisDistance2d(dist=z_score, mean=mu, std=sigma)
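
Taken together, the renames above change how the public API is called. A minimal sketch, assuming `load_transformer_model` and `compute_transformer_zscore` are importable from `distmetrics.transformer`; the array shapes, number of pre-images, and keyword values are illustrative assumptions:

```python
import numpy as np

from distmetrics.transformer import compute_transformer_zscore, load_transformer_model

# Hypothetical dual-polarization inputs: backscatter-like values in (0, 1).
pre_imgs_copol = [np.random.rand(128, 128).astype('float32') for _ in range(4)]
pre_imgs_crosspol = [np.random.rand(128, 128).astype('float32') for _ in range(4)]
post_arr_copol = np.random.rand(128, 128).astype('float32')
post_arr_crosspol = np.random.rand(128, 128).astype('float32')

# Renamed from `load_trained_transformer_model` in 0.0.4.
model = load_transformer_model()

# Inputs renamed from `vv`/`vh` to `copol`/`crosspol`; `tqdm_enabled=False` silences the progress bars.
m_dist = compute_transformer_zscore(
    model,
    pre_imgs_copol,
    pre_imgs_crosspol,
    post_arr_copol,
    post_arr_crosspol,
    stride=4,
    batch_size=32,
    tqdm_enabled=False,
)
```

The returned object bundles the per-pixel z-score with the estimated mean and standard deviation, as constructed in `DiagMahalanobisDistance2d` above.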
