Commit 402bfe8

lint

Benjamin Morris committed Aug 2, 2024
1 parent 602a3e0 commit 402bfe8
Showing 7 changed files with 194 additions and 133 deletions.
60 changes: 35 additions & 25 deletions cyto_dl/image/transforms/generate_jepa_masks.py
@@ -1,16 +1,25 @@
-from monai.transforms import RandomizableTransform
+from typing import Tuple
+
 import numpy as np
-from skimage.segmentation import find_boundaries
 from einops import rearrange
-from typing import Tuple
+from monai.transforms import RandomizableTransform
+from skimage.segmentation import find_boundaries
 
 
 class JEPAMaskGenerator(RandomizableTransform):
-    """
-    Transform for generating Block-contiguous masks for JEPA training. This class works by randomly adding mask blocks until the mask_ratio is met or exceeded, then
-    removing blocks from the exterior of the contiguous mask until the mask_ratio is met exactly.
-    """
-    def __init__(self, mask_size:int=12, block_aspect_ratio: Tuple[float]=(0.5,1.5), num_patches: Tuple[float]=(6, 24, 24), mask_ratio: float=0.9):
+    """Transform for generating Block-contiguous masks for JEPA training.
+
+    This class works by randomly adding mask blocks until the mask_ratio is met or exceeded, then
+    removing blocks from the exterior of the contiguous mask until the mask_ratio is met exactly.
+    """
+
+    def __init__(
+        self,
+        mask_size: int = 12,
+        block_aspect_ratio: Tuple[float] = (0.5, 1.5),
+        num_patches: Tuple[float] = (6, 24, 24),
+        mask_ratio: float = 0.9,
+    ):
         """
         Parameters
         ----------
@@ -24,8 +33,10 @@ def __init__(self, mask_size:int=12, block_aspect_ratio: Tuple[float]=(0.5,1.5),
         The proportion of the image to be masked
         """
         assert mask_ratio < 1, "mask_ratio must be less than 1"
-        assert mask_size * max(block_aspect_ratio) < min(num_patches[-2:]), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"
+        assert mask_size * max(block_aspect_ratio) < min(
+            num_patches[-2:]
+        ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"
 
         self.mask_size = mask_size
         self.block_aspect_ratio = block_aspect_ratio
         self.num_patches = num_patches
@@ -43,12 +54,11 @@ def __init__(self, mask_size:int=12, block_aspect_ratio: Tuple[float]=(0.5,1.5),
             self.edge_mask[1:-1, 1:-1] = 0
         else:
             raise ValueError("num_patches must be 2 or 3 dimensions")
 
     def remove_excess_pixels(self, mask):
-        """
-        Remove pixels along the boundary of the mask until the target number of pixels is reached
-        """
-        bound = find_boundaries(mask, mode='inner')
+        """Remove pixels along the boundary of the mask until the target number of pixels is
+        reached."""
+        bound = find_boundaries(mask, mode="inner")
         # include image edge as boundary, not just 1:0 transitions
         edge_mask = np.logical_and(mask, self.edge_mask)
         bound = np.logical_or(bound, edge_mask)
@@ -59,10 +69,10 @@ def remove_excess_pixels(self, mask):
         remove_coords = bound_coords[remove]
         if self.spatial_dims == 3:
             mask[remove_coords[:, 0], remove_coords[:, 1], remove_coords[:, 2]] = 0
-            mask = rearrange(mask, 'z y x -> (z y x)').astype(bool)
+            mask = rearrange(mask, "z y x -> (z y x)").astype(bool)
         else:
             mask[remove_coords[:, 0], remove_coords[:, 1]] = 0
-            mask = rearrange(mask, 'y x -> (y x)').astype(bool)
+            mask = rearrange(mask, "y x -> (y x)").astype(bool)
         return mask
 
     def __call__(self, img_dict):
@@ -73,21 +83,21 @@ def __call__(self, img_dict):
         while mask.sum() < self.target_pix:
             # randomly select block shape
             aspect_ratio = self.R.uniform(*self.block_aspect_ratio)
-            width = int(self.mask_size*aspect_ratio)
-            height = int(self.mask_size/aspect_ratio)
+            width = int(self.mask_size * aspect_ratio)
+            height = int(self.mask_size / aspect_ratio)
             # randomly select block position
-            x = self.R.randint(0, self.num_patches[-1]-width+1)
-            y = self.R.randint(0, self.num_patches[-2]-height+1)
+            x = self.R.randint(0, self.num_patches[-1] - width + 1)
+            y = self.R.randint(0, self.num_patches[-2] - height + 1)
             # add block to mask
             if self.spatial_dims == 3:
-                mask[:, y:y+height, x:x+width] = 1
+                mask[:, y : y + height, x : x + width] = 1
             else:
-                mask[y:y+height, x:x+width] = 1
+                mask[y : y + height, x : x + width] = 1
 
         mask = self.remove_excess_pixels(mask)
 
         context_mask = np.argwhere(~mask).squeeze()
         target_mask = np.argwhere(mask).squeeze()
-        img_dict['context_mask'] = context_mask
-        img_dict['target_mask'] = target_mask
-        return img_dict
+        img_dict["context_mask"] = context_mask
+        img_dict["target_mask"] = target_mask
+        return img_dict
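
The transform above is a MONAI RandomizableTransform: called on an image dictionary, it grows a block-contiguous boolean mask over the patch grid, trims boundary pixels until the mask ratio is exact, and attaches the flattened patch indices as "context_mask" and "target_mask". A minimal usage sketch (the "img" entry and its shape are illustrative assumptions; only the constructor defaults and mask keys come from the diff):

    import numpy as np

    from cyto_dl.image.transforms.generate_jepa_masks import JEPAMaskGenerator

    # Defaults from the diff: base block size 12 on a 6x24x24 patch grid, mask_ratio 0.9
    gen = JEPAMaskGenerator()
    gen.set_random_state(seed=42)  # MONAI Randomizable API, for reproducibility

    out = gen({"img": np.zeros((1, 6, 24, 24), dtype=np.float32)})  # placeholder image

    total = 6 * 24 * 24
    print(len(out["target_mask"]) / total)   # ~0.9: indices of masked patches
    print(len(out["context_mask"]) / total)  # ~0.1: indices of visible patches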
2 changes: 1 addition & 1 deletion cyto_dl/models/jepa/__init__.py
@@ -1 +1 @@
-from .jepa_base import JEPABase
\ No newline at end of file
+from .jepa_base import JEPABase
73 changes: 41 additions & 32 deletions cyto_dl/models/jepa/ijepa.py
@@ -1,48 +1,57 @@
 import torch.nn as nn
 
 from cyto_dl.models.jepa import JEPABase
 
 
 class IJEPA(JEPABase):
     def __init__(
-        self,
-        *,
-        encoder: nn.Module,
-        predictor: nn.Module,
-        x_key: str,
-        save_dir: str= './',
-        momentum: float=0.998,
-        max_epochs: int=100,
+        self,
+        *,
+        encoder: nn.Module,
+        predictor: nn.Module,
+        x_key: str,
+        save_dir: str = "./",
+        momentum: float = 0.998,
+        max_epochs: int = 100,
         **base_kwargs,
-    ):
-        """
-        Initialize the IJEPA model.
-
-        Parameters
-        ----------
-        encoder : nn.Module
-            The encoder module used for feature extraction.
-        predictor : nn.Module
-            The predictor module used for generating predictions.
-        x_key : str
-            The key used to access the input data.
-        momentum : float, optional
-            The momentum value for the exponential moving average of the model weights (default is 0.998).
-        max_epochs : int, optional
-            The maximum number of training epochs (default is 100).
-        **base_kwargs : dict
-            Additional arguments passed to the BaseModel.
-        """
-        super().__init__(encoder=encoder, predictor=predictor, x_key=x_key, save_dir=save_dir, momentum=momentum, max_epochs=max_epochs, **base_kwargs)
+    ):
+        """JEPA for self-supervised learning on 2D and 3D images.
+
+        Parameters
+        ----------
+        encoder : nn.Module
+            The encoder module used for feature extraction.
+        predictor : nn.Module
+            The predictor module used for generating predictions.
+        x_key : str
+            The key used to access the input data.
+        momentum : float, optional
+            The momentum value for the exponential moving average of the model weights (default is 0.998).
+        max_epochs : int, optional
+            The maximum number of training epochs (default is 100).
+        **base_kwargs : dict
+            Additional arguments passed to the BaseModel.
+        """
+        super().__init__(
+            encoder=encoder,
+            predictor=predictor,
+            x_key=x_key,
+            save_dir=save_dir,
+            momentum=momentum,
+            max_epochs=max_epochs,
+            **base_kwargs,
+        )
 
     def model_step(self, stage, batch, batch_idx):
         self.update_teacher()
-        input=batch[self.hparams.x_key]
+        input = batch[self.hparams.x_key]
 
-        target_masks = self.get_mask(batch, 'target_mask')
-        context_masks = self.get_mask(batch, 'context_mask')
+        target_masks = self.get_mask(batch, "target_mask")
+        context_masks = self.get_mask(batch, "context_mask")
 
         target_embeddings = self.get_target_embeddings(input, target_masks)
         context_embeddings = self.get_context_embeddings(input, context_masks)
-        predictions= self.predictor(context_embeddings, target_masks)
+        predictions = self.predictor(context_embeddings, target_masks)
 
         loss = self.loss(predictions, target_embeddings)
         return loss, None, None
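
update_teacher and the embedding helpers live in JEPABase and are not part of this diff. For orientation, a momentum of 0.998 typically drives an exponential-moving-average teacher update of roughly this shape (a generic sketch, not the JEPABase implementation):

    import copy

    import torch
    import torch.nn as nn


    @torch.no_grad()
    def ema_update(student: nn.Module, teacher: nn.Module, momentum: float = 0.998) -> None:
        # Teacher parameters drift slowly toward the student's after each optimizer step
        for s, t in zip(student.parameters(), teacher.parameters()):
            t.mul_(momentum).add_(s, alpha=1 - momentum)


    student = nn.Linear(8, 8)
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad_(False)  # the teacher is never touched by the optimizer

    ema_update(student, teacher)

With momentum close to 1 the teacher changes slowly, which stabilizes the target embeddings that model_step regresses onto.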
69 changes: 43 additions & 26 deletions cyto_dl/models/jepa/iwm.py
@@ -1,25 +1,27 @@
+from pathlib import Path
+
+import pandas as pd
 import torch
 import torch.nn as nn
-from einops import rearrange
-from cyto_dl.models.base_model import BaseModel
-import pandas as pd
-from pathlib import Path
-from einops import repeat
+from einops import rearrange, repeat
 
-class IJEPA(BaseModel):
+from cyto_dl.models.jepa import JEPABase
+
+
+class IWM(JEPABase):
     def __init__(
         self,
         *,
         encoder: nn.Module,
         predictor: nn.Module,
         x_key: str,
-        save_dir: str= './',
-        momentum: float=0.998,
-        max_epochs: int=100,
+        save_dir: str = "./",
+        momentum: float = 0.998,
+        max_epochs: int = 100,
         **base_kwargs,
     ):
-        """
-        Initialize the IJEPA model.
+        """Image World Model for self-supervised learning of encoder and predictor of
+        transformations in image latent space.
 
         Parameters
         ----------
@@ -36,46 +38,61 @@ def __init__(
         **base_kwargs : dict
             Additional arguments passed to the BaseModel.
         """
-        super().__init__(encoder=encoder, predictor=predictor, x_key=x_key, save_dir=save_dir, momentum=momentum, max_epochs=max_epochs, **base_kwargs)
+        super().__init__(
+            encoder=encoder,
+            predictor=predictor,
+            x_key=x_key,
+            save_dir=save_dir,
+            momentum=momentum,
+            max_epochs=max_epochs,
+            **base_kwargs,
+        )
 
     def model_step(self, stage, batch, batch_idx):
         self.update_teacher()
-        source = batch[f'{self.hparams.x_key}_brightfield']
-        target = batch[f'{self.hparams.x_key}_struct']
+        source = batch[f"{self.hparams.x_key}_brightfield"]
+        target = batch[f"{self.hparams.x_key}_struct"]
 
-        target_masks = self.get_mask(batch, 'target_mask')
-        context_masks = self.get_mask(batch, 'context_mask')
+        target_masks = self.get_mask(batch, "target_mask")
+        context_masks = self.get_mask(batch, "context_mask")
         target_embeddings = self.get_target_embeddings(target, target_masks)
         context_embeddings = self.get_context_embeddings(source, context_masks)
-        predictions= self.predictor(context_embeddings, target_masks, batch['structure_name'])
+        predictions = self.predictor(context_embeddings, target_masks, batch["structure_name"])
 
         loss = self.loss(predictions, target_embeddings)
         return loss, None, None
 
     def get_predict_masks(self, batch_size, num_patches=[4, 16, 16]):
         mask = torch.ones(num_patches, dtype=bool)
-        mask = rearrange(mask, 'z y x -> (z y x)')
+        mask = rearrange(mask, "z y x -> (z y x)")
         mask = torch.argwhere(mask).squeeze()
 
-        return repeat(mask, 't -> t b', b=batch_size)
+        return repeat(mask, "t -> t b", b=batch_size)
 
     def predict_step(self, batch, batch_idx):
-        source = batch[f'{self.hparams.x_key}_brightfield'].squeeze(0)
-        target = batch[f'{self.hparams.x_key}_struct'].squeeze(0)
+        source = batch[f"{self.hparams.x_key}_brightfield"].squeeze(0)
+        target = batch[f"{self.hparams.x_key}_struct"].squeeze(0)
 
         # use predictor to predict each patch
         target_masks = self.get_predict_masks(source.shape[0])
         # mean across patches
         bf_embeddings = self.encoder(source)
-        pred_target_embeddings = self.predictor(bf_embeddings, target_masks, batch['structure_name']).mean(axis=1)
-        pred_feats = pd.DataFrame(pred_target_embeddings.detach().cpu().numpy(), columns=[f'{i}_pred' for i in range(pred_target_embeddings.shape[1])])
+        pred_target_embeddings = self.predictor(
+            bf_embeddings, target_masks, batch["structure_name"]
+        ).mean(axis=1)
+        pred_feats = pd.DataFrame(
+            pred_target_embeddings.detach().cpu().numpy(),
+            columns=[f"{i}_pred" for i in range(pred_target_embeddings.shape[1])],
+        )
 
         # get target embeddings
         target_embeddings = self.encoder(target).mean(axis=1)
-        ctxt_feats = pd.DataFrame(target_embeddings.detach().cpu().numpy(), columns=[f'{i}_ctxt' for i in range(target_embeddings.shape[1])])
+        ctxt_feats = pd.DataFrame(
+            target_embeddings.detach().cpu().numpy(),
+            columns=[f"{i}_ctxt" for i in range(target_embeddings.shape[1])],
+        )
 
         all_feats = pd.concat([ctxt_feats, pred_feats], axis=1)
 
         all_feats.to_csv(Path(self.hparams.save_dir) / f"{batch_idx}_predictions.csv")
         return None, None, None
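
predict_step above writes one CSV per batch in which the "_ctxt" columns hold mean-pooled embeddings of the encoded structure image and the "_pred" columns hold the embeddings predicted from brightfield. A post-hoc analysis sketch, assuming only that layout (the read location and the cosine-similarity metric are illustrative):

    from pathlib import Path

    import numpy as np
    import pandas as pd

    save_dir = Path("./")  # wherever save_dir pointed during prediction
    df = pd.concat(
        pd.read_csv(f, index_col=0) for f in sorted(save_dir.glob("*_predictions.csv"))
    )

    actual = df.filter(like="_ctxt").to_numpy()
    pred = df.filter(like="_pred").to_numpy()

    # Row-wise cosine similarity between predicted and actual embeddings
    cos = (actual * pred).sum(1) / (
        np.linalg.norm(actual, axis=1) * np.linalg.norm(pred, axis=1)
    )
    print(cos.mean())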
