From b8d8328000df30902f0f47ab221a88f78ca881b2 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Wed, 10 Jul 2024 15:41:09 -0700 Subject: [PATCH] remove oopsie committed files --- cyto_dl/models/contrastive/maximum_entropy.py | 142 ------------- cyto_dl/models/contrastive/triplet.py | 134 ------------ .../models/contrastive/triplet_nucmorph.py | 196 ------------------ 3 files changed, 472 deletions(-) delete mode 100644 cyto_dl/models/contrastive/maximum_entropy.py delete mode 100644 cyto_dl/models/contrastive/triplet.py delete mode 100644 cyto_dl/models/contrastive/triplet_nucmorph.py diff --git a/cyto_dl/models/contrastive/maximum_entropy.py b/cyto_dl/models/contrastive/maximum_entropy.py deleted file mode 100644 index d7b47449..00000000 --- a/cyto_dl/models/contrastive/maximum_entropy.py +++ /dev/null @@ -1,142 +0,0 @@ -import sys -from pathlib import Path - -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from torchmetrics import MeanMetric - -from cyto_dl.models.base_model import BaseModel -from cyto_dl.nn.maximum_entropy import MECLoss - -from sklearn.decomposition import PCA -import matplotlib.pyplot as plt - - - -class MaximumEntropy(BaseModel): - def __init__( - self, - batch_size: int, - epochs: int, - n_iter_per_epoch: int, - eps_d: float = 64, - save_every_n_epochs: int = 100, - save_dir: str = "./", - *, - model: nn.Module, - **base_kwargs, - ): - """ - Parameters - ---------- - model: nn.Module - model network, parameters are shared between task heads - x_key: str - key of input image in batch - save_dir="./" - directory to save images during training and validation - save_images_every_n_epochs=1 - Frequency to save out images during training - compile: False - Whether to compile the model using torch.compile - **base_kwargs: - Additional arguments passed to BaseModel - """ - _DEFAULT_METRICS = { - "train/loss": MeanMetric(), - "val/loss": MeanMetric(), - "test/loss": MeanMetric(), - } - metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS) - super().__init__(metrics=metrics, **base_kwargs) - - self.model = model - self.loss_fn = MECLoss - # mu is only used for MEC metric - d = self.model.dim - eps_d /= d - print(f"eps_d: {eps_d}") - lamda = 1/(batch_size * eps_d) - - self.lambda_scheduler = self.lamda_scheduler(8/lamda, 1/lamda, epochs, n_iter_per_epoch, warmup_epochs=10) - self.momentum_scheduler = self.cosine_scheduler(0.996, 1, epochs, n_iter_per_epoch) - self.automatic_optimization = False - - def cosine_scheduler(self, base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0): - warmup_schedule = np.array([]) - warmup_iters = warmup_epochs * niter_per_ep - if warmup_epochs > 0: - warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) - - iters = np.arange(epochs * niter_per_ep - warmup_iters) - schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) - - schedule = np.concatenate((warmup_schedule, schedule)) - assert len(schedule) == epochs * niter_per_ep - return schedule - - - def lamda_scheduler(self, start_warmup_value, base_value, epochs, niter_per_ep, warmup_epochs=5): - warmup_schedule = np.array([]) - warmup_iters = warmup_epochs * niter_per_ep - if warmup_epochs > 0: - warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) - - schedule = np.ones(epochs * niter_per_ep - warmup_iters) * base_value - schedule = np.concatenate((warmup_schedule, schedule)) - assert len(schedule) == epochs * niter_per_ep - return schedule - - - def forward(self, x1, x2): - return self.model(x1, x2) - - def plot_classes(self, predictions, labels): - # calculate pca on predictions and label by labels - pca = PCA(n_components=2) - predictions= predictions.detach().cpu().numpy() - labels = labels.detach().cpu().numpy() - pca.fit(predictions) - pca_predictions = pca.transform(predictions) - - # plot pca - fig, ax = plt.subplots() - scatter = ax.scatter(pca_predictions[:, 0], pca_predictions[:, 1], c=labels) - legend1 = ax.legend(*scatter.legend_elements(), title="Classes") - ax.add_artist(legend1) - fig.savefig(Path(self.hparams.save_dir) / f"{self.current_epoch}_pca.png") - plt.close(fig) - - def model_step(self, stage, batch, batch_idx): - if 'label' not in batch: - z1, z2, p1, p2 = self.forward(batch["image"], batch["image_aug"]) - lambda_inv = self.lambda_scheduler[self.global_step] - momentum = self.momentum_scheduler[self.global_step] - mec_loss = (self.loss_fn(p1, z2, lambda_inv) + self.loss_fn(p2, z1, lambda_inv)) * 0.5 / self.hparams.batch_size - loss = -1 * mec_loss * lambda_inv - if stage == 'train' and batch_idx==0 and (self.current_epoch + 1) % self.hparams.save_every_n_epochs == 0: - self.plot_classes(z1, batch["target"]) - else: - class_pred= self.forward(batch['image']) - loss = self.loss_fn(class_pred, batch['label']) - # calculate accuracy - pred = torch.argmax(class_pred, dim=1) - acc = torch.sum(pred == batch['label']).item() / len(pred) - self.log(f"{stage}/acc", acc) - - if stage == 'train': - opt = self.optimizers() - opt.zero_grad() - self.manual_backward(loss) - opt.step() - - # momentum update of the parameters of the teacher network - with torch.no_grad(): - for param_q, param_k in zip(self.model.encoder.parameters(), self.model.teacher.parameters()): - param_k.data.mul_(momentum).add_((1 - momentum) * param_q.detach().data) - - - - return loss, None, None diff --git a/cyto_dl/models/contrastive/triplet.py b/cyto_dl/models/contrastive/triplet.py deleted file mode 100644 index 24bb5e88..00000000 --- a/cyto_dl/models/contrastive/triplet.py +++ /dev/null @@ -1,134 +0,0 @@ -from pathlib import Path - -import torch -import torch.nn as nn -import torch.nn.functional -from torchmetrics import MeanMetric - -from cyto_dl.models.base_model import BaseModel - -from sklearn.decomposition import PCA -import matplotlib.pyplot as plt - - - -class Triplet(BaseModel): - def __init__( - self, - save_every_n_epochs: int = 100, - save_dir: str = "./", - *, - model: nn.Module, - **base_kwargs, - ): - """ - Parameters - ---------- - model: nn.Module - model network, parameters are shared between task heads - x_key: str - key of input image in batch - save_dir="./" - directory to save images during training and validation - save_images_every_n_epochs=1 - Frequency to save out images during training - compile: False - Whether to compile the model using torch.compile - **base_kwargs: - Additional arguments passed to BaseModel - """ - _DEFAULT_METRICS = { - # "train/loss/mean_positive_dist": MeanMetric(), - # "train/loss/closest_negative_dist": MeanMetric(), - # "val/loss/mean_positive_dist": MeanMetric(), - # "val/loss/closest_negative_dist": MeanMetric(), - "train/loss": MeanMetric(), - "val/loss": MeanMetric(), - "test/loss": MeanMetric(), - } - metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS) - super().__init__(metrics=metrics, **base_kwargs) - - self.model = model - self.loss_fn = torch.nn.TripletMarginLoss(margin=1.0) - - def forward(self, x1, x2): - return self.model(x1, x2) - - def plot_classes(self, predictions, labels): - # calculate pca on predictions and label by labels - pca = PCA(n_components=2) - predictions= predictions.detach().cpu().numpy() - labels = labels.detach().cpu().numpy() - pca.fit(predictions) - pca_predictions = pca.transform(predictions) - - # plot pca - fig, ax = plt.subplots() - scatter = ax.scatter(pca_predictions[:, 0], pca_predictions[:, 1], c=labels) - legend1 = ax.legend(*scatter.legend_elements(), title="Classes") - ax.add_artist(legend1) - fig.savefig(Path(self.hparams.save_dir) / f"{self.current_epoch}_pca.png") - plt.close(fig) - - def find_hard_negatives(self, pairwise_dist, negatives_mask): - # exlude positives - pairwise_dist[~negatives_mask] = torch.inf - hard_negative_idx = torch.argmin(pairwise_dist, dim=1) - return hard_negative_idx - - def find_hard_positives(self, pairwise_dist, negatives_mask): - # exclude negatives and self - pairwise_dist[negatives_mask] = -torch.inf - pairwise_dist[torch.eye(pairwise_dist.shape[0]).bool()] = -torch.inf - - hard_positive_idx = torch.argmax(pairwise_dist, dim=1) - return hard_positive_idx - - - def model_step(self, stage, batch, batch_idx): - anchor_embeddings = self.model(batch["image"].squeeze(1)) - anchor_embeddings= torch.nn.functional.normalize(anchor_embeddings, p=2, dim=1) - - # positive_embeddings = self.model(batch['image_aug'].squeeze(1)) - # positive_embeddings= torch.nn.functional.normalize(positive_embeddings, p=2, dim=1) - # find pairwisel2 distance between embeddings - # pairwise_dist = torch.cdist(anchor_embeddings, positive_embeddings, p=2) - - pairwise_dist = torch.cdist(anchor_embeddings, anchor_embeddings, p=2) - targets = batch["target"].unsqueeze(1).float() - negatives_mask = torch.cdist(targets, targets, p=0).bool() - - hard_negative_idx = self.find_hard_negatives(pairwise_dist.clone(), negatives_mask) - negative_embeddings = anchor_embeddings[hard_negative_idx] - - hard_positive_idx = self.find_hard_positives(pairwise_dist.clone(), negatives_mask) - positive_embeddings = anchor_embeddings[hard_positive_idx] - - # # count how many hard negatives are used per label - # hard_negative_counts = torch.bincount(batch["target"][hard_negative_idx]) - # print(hard_negative_counts) - - #reorder anchor embeddings to be matched as negative embeddings - # negative_embeddings = anchor_embeddings[hard_negative_idx] - - # find triplet loss - loss = self.loss_fn(anchor_embeddings, positive_embeddings, negative_embeddings) - - # with torch.no_grad(): - # loss ={ - # 'loss': loss, - # 'mean_positive_dist': torch.mean(pairwise_dist[~negatives_mask]).item(), - # 'closest_negative_dist': torch.mean(torch.diagonal(pairwise_dist[hard_negative_idx])).item() - # } - - if stage == 'val' and batch_idx==0: - print('AVG HARD NEGATIVE DISTANCE:', pairwise_dist[torch.arange(pairwise_dist.shape[0]), hard_negative_idx].mean()) - print('AVG HARD POSITIVE DISTANCE:', pairwise_dist[torch.arange(pairwise_dist.shape[0]), hard_positive_idx].mean()) - self.plot_classes(anchor_embeddings, batch["target"]) - # from aicsimageio.writers import OmeTiffWriter - # OmeTiffWriter.save(uri=Path(self.hparams.save_dir) / f"{self.current_epoch}_anchors.tiff", data = batch['image'].squeeze().detach().cpu().numpy()) - # OmeTiffWriter.save(uri=Path(self.hparams.save_dir) / f"{self.current_epoch}_positives.tiff", data = batch['image_aug'].squeeze().detach().cpu().numpy()) - # OmeTiffWriter.save(uri=Path(self.hparams.save_dir) / f"{self.current_epoch}_negatives.tiff", data = batch['image'][max_indices].squeeze().detach().cpu().numpy()) - - return loss, None, None diff --git a/cyto_dl/models/contrastive/triplet_nucmorph.py b/cyto_dl/models/contrastive/triplet_nucmorph.py deleted file mode 100644 index 6e8b3f58..00000000 --- a/cyto_dl/models/contrastive/triplet_nucmorph.py +++ /dev/null @@ -1,196 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional -from torchmetrics import MeanMetric -from sklearn.decomposition import PCA -import matplotlib.pyplot as plt - -from cyto_dl.models.base_model import BaseModel -from pathlib import Path -import pandas as pd -import numpy as np - -class Triplet(BaseModel): - def __init__( - self, - save_dir, - *, - model: nn.Module, - **base_kwargs, - ): - """ - Parameters - ---------- - model: nn.Module - model network, parameters are shared between task heads - x_key: str - key of input image in batch - save_dir="./" - directory to save images during training and validation - save_images_every_n_epochs=1 - Frequency to save out images during training - compile: False - Whether to compile the model using torch.compile - **base_kwargs: - Additional arguments passed to BaseModel - """ - _DEFAULT_METRICS = { - "train/loss": MeanMetric(), - "val/loss": MeanMetric(), - "test/loss": MeanMetric(), - } - metrics = base_kwargs.pop("metrics", _DEFAULT_METRICS) - super().__init__(metrics=metrics, **base_kwargs) - - self.model = model - self.loss_fn = torch.nn.TripletMarginLoss(margin=1.0) - - - def forward(self, x1, x2): - return self.model(x1, x2) - - - def plot_classes(self, anchor_embeddings, positive_embeddings, negative_embeddings): - # calculate pca on predictions and label by labels - pca = PCA(n_components=2) - pca.fit(anchor_embeddings) - - random_examples = np.random.choice(anchor_embeddings.shape[0], 10) - anchor_embeddings = pca.transform(anchor_embeddings)[random_examples] - positive_embeddings = pca.transform(positive_embeddings)[random_examples] - negative_embeddings = pca.transform(negative_embeddings)[random_examples] - - fig, ax = plt.subplots() - - # plot anchor embeddings in gray - ax.scatter(anchor_embeddings[:, 0], anchor_embeddings[:, 1], c='gray') - - # plot positive embeddings in green - ax.scatter(positive_embeddings[:, 0], positive_embeddings[:, 1], c='green') - - # plot negative embeddings in red - ax.scatter(negative_embeddings[:, 0], negative_embeddings[:, 1], c='red') - - - # draw lines between anchor and positive, anchor and negative - ax.plot([anchor_embeddings[:, 0], positive_embeddings[:, 0]], [anchor_embeddings[:, 1], positive_embeddings[:, 1]], 'green') - ax.plot([anchor_embeddings[:, 0], negative_embeddings[:, 0]], [anchor_embeddings[:, 1], negative_embeddings[:, 1]], 'red', alpha=0.1) - - - fig.savefig(Path(self.hparams.save_dir) / f"{self.current_epoch}_pca.png") - plt.close(fig) - - - def model_step(self, stage, batch, batch_idx): - anchor_embeddings = self.model(batch["anchor"]) - anchor_embeddings= torch.nn.functional.normalize(anchor_embeddings, p=2, dim=1) - - positive_embeddings = self.model(batch['positive']) - positive_embeddings= torch.nn.functional.normalize(positive_embeddings, p=2, dim=1) - - negative_embeddings = self.model(batch['negative']) - negative_embeddings= torch.nn.functional.normalize(negative_embeddings, p=2, dim=1) - - # find triplet loss - loss = self.loss_fn(anchor_embeddings, positive_embeddings, negative_embeddings) - - if stage == 'val' and batch_idx == 0: - with torch.no_grad(): - self.plot_classes(anchor_embeddings.detach().cpu().numpy(), positive_embeddings.detach().cpu().numpy(), negative_embeddings.detach().cpu().numpy()) - - return loss, None, None - - # def predict_step(self, batch, batch_idx): - # from monai import transforms - # import tqdm - - # cell_ids = batch['cell_id'] - # anchor = batch['anchor'] - # embeddings_anchor = self.model(anchor).detach().cpu().numpy() - - - # embeddings_anchor = pd.DataFrame(embeddings_anchor, columns = [str(i) for i in range(embeddings_anchor.shape[1])]) - # embeddings_anchor['cell_id'] = cell_ids - # embeddings_anchor['name'] = 'anchor' - - # embeddings_negative = self.model(batch['negative']).detach().cpu().numpy() - # embeddings_negative = pd.DataFrame(embeddings_negative, columns = [str(i) for i in range(embeddings_negative.shape[1])]) - # embeddings_negative['cell_id'] = cell_ids - # embeddings_negative['name'] = 'negative' - - - # positive_embeddings = [] - # name = [] - # cellid = [] - # # augs = [] - # for img in tqdm.tqdm(range(anchor.shape[0])): - # aug = anchor[img].clone() - # for i in range(100): - # i_aug = transforms.RandGridDistortion(prob=1)(aug) - # # augs.append(i_aug.detach().cpu().numpy()) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"grid_distort_{i}"] - # cellid += [cell_ids[img]] - - # for std in np.linspace(0.1, 3.0, 5): - # for i in range(100): - # i_aug = transforms.RandGaussianNoise(prob=1, std=std)(aug) - # # augs.append(i_aug.detach().cpu().numpy()) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"gaussian_noise_{std}"] - # cellid += [cell_ids[img]] - - # for i in range(4): - # i_aug = transforms.Rotate90(k=i, spatial_axes=(1, 2))(aug) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"rotate_90_{i}"] - # cellid += [cell_ids[img]] - # for hflip in [True, False]: - # i_aug = transforms.Flip(spatial_axis=1)(aug) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"hflip_{hflip}"] - # cellid += [cell_ids[img]] - # for vflip in [True, False]: - # i_aug = transforms.Flip(spatial_axis=2)(aug) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"vflip_{vflip}"] - # cellid += [cell_ids[img]] - # for _ in range(100): - # i_aug = transforms.RandHistogramShift(prob=1, num_control_points=(80, 120))(aug) - # positive_embeddings.append(self.model(i_aug.unsqueeze(0)).detach().cpu().numpy()) - # name += [f"intensity"] - # cellid += [cell_ids[img]] - - # # create csv with batch x embeddings and cell_ids - # positive_embeddings = np.stack(positive_embeddings).squeeze(1) - # positive_embeddings = pd.DataFrame(positive_embeddings, columns = [str(i) for i in range(positive_embeddings.shape[1])]) - # positive_embeddings['cell_id'] = cellid - # positive_embeddings['name'] = name - - - # all = pd.concat([embeddings_anchor, embeddings_negative, positive_embeddings]) - # all.to_csv(Path(self.hparams.save_dir) / f"{batch_idx}_embeddings.csv", index=False) - - # from aicsimageio.writers import OmeTiffWriter - # OmeTiffWriter.save(uri = Path(self.hparams.save_dir) / f"{batch_idx}_aug.ome.tiff", data = np.stack(augs), dimension_order='CZYX') - - # OmeTiffWriter.save(uri = Path(self.hparams.save_dir) / f"{batch_idx}_anchor.ome.tiff", data = anchor.detach().cpu().numpy(), dimension_order='CZYX') - - # OmeTiffWriter.save(uri = Path(self.hparams.save_dir) / f"{batch_idx}_negative.ome.tiff", data = batch['negative'].detach().cpu().numpy(), dimension_order='CZYX') - - # quit() - - - - - - - def predict_step(self, batch, batch_idx): - cell_ids = batch['cell_id'] - embeddings = self.model(batch['anchor']).detach().cpu().numpy() - - # create csv with batch x embeddings and cell_ids - data = pd.DataFrame(embeddings, columns = [str(i) for i in range(embeddings.shape[1])]) - data['cell_id'] = cell_ids - data.to_csv(Path(self.hparams.save_dir) / f"{batch_idx}_embeddings.csv", index=False) -