Skip to content

Commit

Permalink
adding other paper code
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Nov 21, 2024
1 parent 9398cac commit 537c65d
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 42 deletions.
2 changes: 1 addition & 1 deletion cellpose/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

MODEL_NAMES = []
for ctype in ["cyto3", "cyto2", "nuclei"]:
for ntype in ["denoise", "deblur", "upsample"]:
for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
MODEL_NAMES.append(f"{ntype}_{ctype}")
if ctype != "cyto3":
for ltype in ["per", "seg", "rec"]:
Expand Down
101 changes: 101 additions & 0 deletions paper/3.0/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pathlib import Path
import torch
from torch import nn
from tqdm import trange

# in same folder
try:
Expand Down Expand Up @@ -299,6 +300,106 @@ def real_examples(folder):
dat2["ap_n2v"] = ap_n2v
np.save(root / "n2v_masks.npy", dat2)

def real_examples_ribo(root):
navgs = [1, 2, 4, 8, 16, 32, 64]
noisy = [[], [], [], [], [], [], []]
clean = []
for i in [1, 3, 6, 4, 5]:
imgs = io.imread(Path(root) / f"denoise_{i:05d}_00001.tif")[:300]
imgs = [imgs[:, :512, :512], imgs[:, 512:, :512], imgs[:, :512, 512:], imgs[:, 512:, 512:]]
clean.extend([img.mean(axis=0) for img in imgs])
for n, navg in enumerate(navgs):
iavg = np.linspace(0, len(imgs[0])-1, navg+2).astype(int)[1:-1]
noisy[n].extend(np.array([img[iavg].mean(axis=0) for img in imgs]))
print(len(clean), len(noisy[0]))

thresholds = np.arange(0.5, 1.05, 0.05)
diameter = 17
normalize = True # {"tile_norm_blocksize": 80}
seg_model = models.Cellpose(gpu=True, model_type="cyto2")
model = denoise.DenoiseModel(gpu=True, model_type="denoise_cyto2")
masks = seg_model.eval(clean, diameter=diameter, channels=[0,0],
normalize=normalize)[0]
ap_noisy = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
ap_dn = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
dat = {}
dat["navgs"] = navgs
dat["imgs_dn"] = []
dat["masks_dn"] = []
dat["masks_noisy"] = []
dat["masks_clean"] = masks
dat["noisy"] = noisy
dat["clean"] = clean
for n, imgs in enumerate(noisy):
masks_noisy = seg_model.eval(imgs, diameter=diameter, channels=[0,0],
normalize=normalize)[0]
img_dn = model.eval(imgs, diameter=diameter, channels=[0,0],
normalize=normalize)
ap, tp, fp, fn = metrics.average_precision(masks, masks_noisy, threshold=thresholds)
ap_noisy[n] = ap
masks_dn = seg_model.eval(img_dn, diameter=diameter, channels=[0,0],
normalize=normalize)[0]
ap, tp, fp, fn = metrics.average_precision(masks, masks_dn, threshold=thresholds)
ap_dn[n] = ap
dat["imgs_dn"].append(img_dn)
dat["masks_dn"].append(masks_dn)
dat["masks_noisy"].append(masks_noisy)
print(ap_noisy[n,:,0].mean(axis=0), ap_dn[n,:,0].mean(axis=0))
dat["ap_noisy"] = ap_noisy
dat["ap_dn"] = ap_dn
np.save(Path(root) / "ribo_denoise.npy", dat)

dat = {}
dat["navgs"] = navgs
dat["imgs_n2s"] = []
dat["masks_n2s"] = []
dat["masks_clean"] = masks
dat["noisy"] = noisy
dat["clean"] = clean
dat["ap_n2s"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))

for n, imgs in enumerate(noisy):
imgs_n2s = []
for i in trange(len(imgs)):
out = noise2self.train_per_image(imgs[i][np.newaxis,...].astype("float32"))
imgs_n2s.append(out)
imgs_n2s = np.array(imgs_n2s)
masks_n2s = seg_model.eval(imgs_n2s, diameter=diameter, channels=[0,0])[0]
ap, tp, fp, fn = metrics.average_precision(masks, masks_n2s, threshold=thresholds)
dat["ap_n2s"][n] = ap
dat["imgs_n2s"].append(imgs_n2s)
dat["masks_n2s"].append(masks_n2s)
print(n, ap.mean(axis=0)[[0, 5, 8]])

np.save(Path(root) / "ribo_denoise_n2s.npy", dat)

dat = {}
dat["navgs"] = navgs
dat["imgs_n2v"] = []
dat["masks_n2v"] = []
dat["masks_clean"] = masks
dat["noisy"] = noisy
dat["clean"] = clean
dat["ap_n2v"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))

for n, imgs in enumerate(noisy):
imgs_n2v = []
for i in trange(len(imgs)):
out = noise2void.train_per_image(imgs[i].astype("float32"))
imgs_n2v.append(out)
imgs_n2v = np.array(imgs_n2v)
masks_n2v = seg_model.eval(imgs_n2v, diameter=diameter, channels=[0,0],
normalize=normalize)[0]
ap, tp, fp, fn = metrics.average_precision(masks, masks_n2v, threshold=thresholds)
#print(ap[:,0])
dat["ap_n2v"][n] = ap
dat["imgs_n2v"].append(imgs_n2v)
dat["masks_n2v"].append(masks_n2v)
print(n, ap.mean(axis=0)[[0, 5, 8]])

np.save(Path(root) / "ribo_denoise_n2v.npy", dat)



def specialist_training(root):
""" root is path to specialist images (first 89 images of cyto2 and first 11 test images) """
Expand Down
186 changes: 145 additions & 41 deletions paper/3.0/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,9 +1204,9 @@ def suppfig_specialist(folder, save_fig=True):

il = 0

fig = plt.figure(figsize=(9, 5), dpi=100)
fig = plt.figure(figsize=(9, 9), dpi=100)
yratio = 9 / 5
grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
grid = plt.GridSpec(3, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
wspace=0.15, hspace=0.2)

titles = ["train - clean", "train - noisy", "test - noisy"]
Expand Down Expand Up @@ -1265,32 +1265,46 @@ def suppfig_specialist(folder, save_fig=True):
ax.set_xticks(np.arange(0.5, 1.05, 0.1))
ax.set_xlim([0.5, 1.0])

transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans)
grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[1:, :], wspace=0.05,
hspace=0.1)

kk = [2, 3, 4, 10]
transl = mtransforms.ScaledTranslation(-10 / 72, 25 / 72, fig.dpi_scale_trans)

kk = [2, 3, 4, 6, 10]
iex = 8
ylim = [10, 310]
xlim = [100, 500]
ylim = [125, 512] # [0, 350]
xlim = [50, 325] # [100, 500]
legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)"
for j, k in enumerate(kk):
ax = plt.subplot(grid[1, j])
pos = ax.get_position().bounds
ax.set_position([pos[0], pos[1] - 0.07, pos[2], pos[3]])
img0 = imgs_all[k][iex].squeeze()
img0 *= 1.1
img0 = np.clip(img0, 0, 1)
outlines_gt = utils.outlines_list(masks_all[0][iex].T.copy(), multiprocessing=False)
for ii in range(2):
ax = plt.subplot(grid1[ii, j])
pos = ax.get_position().bounds
ax.set_position([pos[0], pos[1] - 0.07 + ii*0.03, pos[2], pos[3]])
img0 = imgs_all[k][iex].squeeze().T
masks0 = masks_all[k][iex].squeeze().T
img0 *= 1.
img0 = np.clip(img0, 0, 1)

ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
ax.axis("off")
ax.set_ylim(ylim)
ax.set_xlim(xlim)
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
ax.text(1, -0.04, f"[email protected] = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
transform=ax.transAxes)
if j == 0:
il = plot_label(ltr, il, ax, transl, fs_title)
ax.text(0.02, 1.2, "Denoised test image", fontsize="large",
fontstyle="italic", transform=ax.transAxes)
ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
if ii==1:
outlines = utils.outlines_list(masks0, multiprocessing=False)
for o in outlines_gt:
ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2)
for o in outlines:
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--")
ax.axis("off")
ax.set_ylim(ylim)
ax.set_xlim(xlim)
if ii==0:
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
else:
ax.text(1, -0.04, f"[email protected] = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
transform=ax.transAxes)
if j == 0 and ii==0:
il = plot_label(ltr, il, ax, transl, fs_title)
ax.text(0.02, 1.15, "Denoised test image", fontsize="large",
fontstyle="italic", transform=ax.transAxes)

print(aps.mean(axis=1)[:, [0, 5, 8]])

Expand Down Expand Up @@ -1493,9 +1507,9 @@ def fig6(folder, save_fig=True):

diams = [utils.diameters(lbl)[0] for lbl in lbls]

gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
gen_model = "oneclick_cyto3" #"/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
pretrained_model=gen_model)
model_type=gen_model)
seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
pscales = [1.5, 20., 1.5, 1., 5., 40., 3.]
denoise.deterministic()
Expand Down Expand Up @@ -1561,6 +1575,7 @@ def fig6(folder, save_fig=True):
legstr0 = ["", u"\u2013 noisy image", u"\u2013 original",
u"\u2013 noise-specific", "\u2013 data-specific", u"-- one-click"]
theight = [0, 0,4,3,2,1]
cstr = ["noisy\nimage", "blurry\nimage", "bilinear\nupsampled"]
for i in range(6):
ctype = "cellpose test set" if i < 3 else "nuclei test set"
noise_type = ["denoising", "deblurring", "upsampling"][i % 3]
Expand All @@ -1580,7 +1595,7 @@ def fig6(folder, save_fig=True):
if i == 1 or i == 4:
ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center",
fontsize="large")

ax.text(0.03, 0.03, cstr[i%3], transform=ax.transAxes, fontsize="small")
ax.set_ylim([0, 0.72])
ax.set_xticks(np.arange(0.5, 1.05, 0.25))
ax.set_xlim([0.5, 1.0])
Expand All @@ -1593,9 +1608,98 @@ def fig6(folder, save_fig=True):
]
colsj = cols0[[0, 1, -1]]

ly0 = 250
generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
titlesj, colsj, titlesi, j0=0, il=il)

if save_fig:
os.makedirs("figs/", exist_ok=True)
fig.savefig("figs/fig6.pdf", dpi=150)

def suppfig_generalist_examples(folder, save_fig=True):
cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162],
[246, 198, 173], [192, 71, 29], ])
cols0 = cols0 / 255
titlesi = [
"Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast",
"Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs"
]
colsj = cols0[[0, 1, -1]]
folders = [
"cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC",
"bact_phase", "bact_fluor", "deepbacs"
]
diam_mean = 30.

#iexs = [340, 50, 10, 5, 70, 2, 33]
iexs = [305, 1071, 0, 3, 70, 9, 31]
imgs, lbls = [[], [], []], []
masks = [[], [], []]
for f, iex in zip(folders[2:], iexs):
dat = np.load(Path(folder) / f"{f}_generalist_masks.npy",
allow_pickle=True).item()
img = dat["imgs"][iex].copy()
img = img[:1] if img.ndim > 2 else img
img = np.maximum(0, transforms.normalize99(img))
imgs[0].append(img)
masks[0].append(dat["masks_pred"][iex])
lbls.append(dat["masks"][iex].astype("uint16"))

diams = [utils.diameters(lbl)[0] for lbl in lbls]

transl = mtransforms.ScaledTranslation(-15 / 72, 30 / 72, fig.dpi_scale_trans)
gen_model = "oneclick_cyto3"
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
model_type=gen_model)
seg_model = models.CellposeModel(gpu=True, model_type="cyto3")

fig = plt.figure(figsize=(14, 8), dpi=100)
grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.03)

for ii in range(2):
if ii==0:
titlesj = ["clean", "blurry", "deblurred (one-click)"]
else:
titlesj = ["clean", "downsampled", "upsampled (one-click)"]
masks[1] = []
masks[2] = []
imgs[1] = []
imgs[2] = []
sigmas = [5., 3., 7., 12., 5., 5., 3.]
ds = [6,4,8,8,6,6,6]
denoise.deterministic()
for i, img in tqdm(enumerate(imgs[0])):
img0 = torch.from_numpy(img.copy()).squeeze().unsqueeze(0).unsqueeze(0)
img0 = img0.float()
noisy0 = denoise.add_noise(img0, poisson=0., downsample=1. if ii==1 else 0,
blur=1., ds=ds[i] if ii==1 else 0,
sigma0 = sigmas[i] if ii==0 else sigmas[i]/2,
sigma1 = sigmas[i] if ii==0 else sigmas[i]/2,
pscale=120.).numpy().squeeze()
denoised0 = model.eval(noisy0, diameter=diams[i], normalize=True)

imgs[1].append(noisy0)
imgs[2].append(denoised0)
for j in range(1, 3):
masks[j].append(
seg_model.eval(
imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5,
flow_threshold=0.4, augment=True, bsize=224,
niter=2000 if folders[i - 2] == "bact_phase" else None)[0])
api = np.array(
[metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)])

generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
titlesj, colsj, titlesi, j0=-1 + 2*ii, letter=True)
if save_fig:
os.makedirs("figs/", exist_ok=True)
fig.savefig("figs/suppfig_genex.pdf", dpi=150)

def generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
titlesj, colsj, titlesi, j0=0, ly0=250, letter=False, il=0):
if letter:
il = j0>0
transl = mtransforms.ScaledTranslation(-20 / 72, 15 / 72, fig.dpi_scale_trans)
else:
transl = mtransforms.ScaledTranslation(-20 / 72, 5 / 72, fig.dpi_scale_trans)
for i in range(len(imgs[0])):
ratio = diams[i] / 30.
d = utils.diameters(lbls[i])[0]
Expand All @@ -1608,20 +1712,18 @@ def fig6(folder, save_fig=True):
for j in range(1, 3):
img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1)
for k in range(2):
ax = plt.subplot(grid[j, 2 * i + k])
ax = plt.subplot(grid[j+j0, 2 * i + k])
pos = ax.get_position().bounds
ax.set_position([
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07,
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.08*(j0==0),
pos[2], pos[3]
])
if 1:
ax.imshow(img, cmap="gray", vmin=0,
vmax=0.35 if j == 1 and i == 2 else 1.0)
vmax=0.35 if j == 1 and i == 2 and j0==0 else 1.0)
if k == 1:
outlines = utils.outlines_list(masks[j][i],
multiprocessing=False)
#for o in outlines_gt:
# ax.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=1, ls="-")
for o in outlines:
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5,
ls="--")
Expand All @@ -1638,17 +1740,19 @@ def fig6(folder, save_fig=True):
if k == 0 and i == 0:
ax.text(-0.22, 0.5, titlesj[j], transform=ax.transAxes, va="center",
rotation=90, color=colsj[j], fontsize="medium")
if j == 0:
if j==1:
il = plot_label(ltr, il, ax, transl, fs_title)
ax.text(-0.0, 1.22, "Denoising examples from other datasets",
ax.text(-0.02, 1.05, "Denoising examples from other datasets",
fontstyle="italic", transform=ax.transAxes,
fontsize="large")
if k == 0 and j == 0:
ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
fontsize="medium")
if save_fig:
os.makedirs("figs/", exist_ok=True)
fig.savefig("figs/fig6.pdf", dpi=150)
if j==1 and letter:
ax.text(-0.0, 1.11, "Deblurring examples from other datasets" if j0==-1 else "Upsampling examples from other datasets",
fontstyle="italic", transform=ax.transAxes,
fontsize="large")
il = plot_label(ltr, il, ax, transl, fs_title)
#if k == 0 and (j == 0 or (j==1 and j0==0)):
#ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
# fontsize="medium")

def load_seg_generalist(folder):
folders = [
Expand Down
Loading

0 comments on commit 537c65d

Please sign in to comment.