From 537c65d329a61d770dcc5f192331a0b1f0a8ba0a Mon Sep 17 00:00:00 2001
From: Carsen Stringer
Date: Thu, 21 Nov 2024 09:40:01 -0500
Subject: [PATCH] adding other paper code

---
 cellpose/denoise.py        |   2 +-
 paper/3.0/analysis.py      | 101 ++++++++++++++++++++
 paper/3.0/figures.py       | 186 +++++++++++++++++++++++++++++--------
 paper/3.0/train_subsets.py | 133 ++++++++++++++++++++++++++
 4 files changed, 380 insertions(+), 42 deletions(-)
 create mode 100644 paper/3.0/train_subsets.py

diff --git a/cellpose/denoise.py b/cellpose/denoise.py
index 1f89db67..f925a71d 100644
--- a/cellpose/denoise.py
+++ b/cellpose/denoise.py
@@ -23,7 +23,7 @@

 MODEL_NAMES = []
 for ctype in ["cyto3", "cyto2", "nuclei"]:
-    for ntype in ["denoise", "deblur", "upsample"]:
+    for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
         MODEL_NAMES.append(f"{ntype}_{ctype}")
     if ctype != "cyto3":
         for ltype in ["per", "seg", "rec"]:
diff --git a/paper/3.0/analysis.py b/paper/3.0/analysis.py
index 2a9eb1de..cf4b4ef5 100644
--- a/paper/3.0/analysis.py
+++ b/paper/3.0/analysis.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 import torch
 from torch import nn
+from tqdm import trange

 # in same folder
 try:
@@ -299,6 +300,106 @@ def real_examples(folder):
     dat2["ap_n2v"] = ap_n2v
     np.save(root / "n2v_masks.npy", dat2)

+def real_examples_ribo(root):
+    navgs = [1, 2, 4, 8, 16, 32, 64]
+    noisy = [[], [], [], [], [], [], []]
+    clean = []
+    for i in [1, 3, 6, 4, 5]:
+        imgs = io.imread(Path(root) / f"denoise_{i:05d}_00001.tif")[:300]
+        imgs = [imgs[:, :512, :512], imgs[:, 512:, :512], imgs[:, :512, 512:], imgs[:, 512:, 512:]]
+        clean.extend([img.mean(axis=0) for img in imgs])
+        for n, navg in enumerate(navgs):
+            iavg = np.linspace(0, len(imgs[0])-1, navg+2).astype(int)[1:-1]
+            noisy[n].extend(np.array([img[iavg].mean(axis=0) for img in imgs]))
+    print(len(clean), len(noisy[0]))
+
+    thresholds = np.arange(0.5, 1.05, 0.05)
+    diameter = 17
+    normalize = True # {"tile_norm_blocksize": 80}
+    seg_model = models.Cellpose(gpu=True, model_type="cyto2")
+    model = denoise.DenoiseModel(gpu=True, model_type="denoise_cyto2")
+    masks = seg_model.eval(clean, diameter=diameter, channels=[0,0],
+                           normalize=normalize)[0]
+    ap_noisy = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
+    ap_dn = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
+    dat = {}
+    dat["navgs"] = navgs
+    dat["imgs_dn"] = []
+    dat["masks_dn"] = []
+    dat["masks_noisy"] = []
+    dat["masks_clean"] = masks
+    dat["noisy"] = noisy
+    dat["clean"] = clean
+    for n, imgs in enumerate(noisy):
+        masks_noisy = seg_model.eval(imgs, diameter=diameter, channels=[0,0],
+                                     normalize=normalize)[0]
+        img_dn = model.eval(imgs, diameter=diameter, channels=[0,0],
+                            normalize=normalize)
+        ap, tp, fp, fn = metrics.average_precision(masks, masks_noisy, threshold=thresholds)
+        ap_noisy[n] = ap
+        masks_dn = seg_model.eval(img_dn, diameter=diameter, channels=[0,0],
+                                  normalize=normalize)[0]
+        ap, tp, fp, fn = metrics.average_precision(masks, masks_dn, threshold=thresholds)
+        ap_dn[n] = ap
+        dat["imgs_dn"].append(img_dn)
+        dat["masks_dn"].append(masks_dn)
+        dat["masks_noisy"].append(masks_noisy)
+        print(ap_noisy[n,:,0].mean(axis=0), ap_dn[n,:,0].mean(axis=0))
+    dat["ap_noisy"] = ap_noisy
+    dat["ap_dn"] = ap_dn
+    np.save(Path(root) / "ribo_denoise.npy", dat)
+
+    dat = {}
+    dat["navgs"] = navgs
+    dat["imgs_n2s"] = []
+    dat["masks_n2s"] = []
+    dat["masks_clean"] = masks
+    dat["noisy"] = noisy
+    dat["clean"] = clean
+    dat["ap_n2s"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
+
+    for n, imgs in enumerate(noisy):
+        imgs_n2s = []
+        for i in trange(len(imgs)):
+            out = noise2self.train_per_image(imgs[i][np.newaxis,...].astype("float32"))
+            imgs_n2s.append(out)
+        imgs_n2s = np.array(imgs_n2s)
+        masks_n2s = seg_model.eval(imgs_n2s, diameter=diameter, channels=[0,0])[0]
+        ap, tp, fp, fn = metrics.average_precision(masks, masks_n2s, threshold=thresholds)
+        dat["ap_n2s"][n] = ap
+        dat["imgs_n2s"].append(imgs_n2s)
+        dat["masks_n2s"].append(masks_n2s)
+        print(n, ap.mean(axis=0)[[0, 5, 8]])
+
+    np.save(Path(root) / "ribo_denoise_n2s.npy", dat)
+
+    dat = {}
+    dat["navgs"] = navgs
+    dat["imgs_n2v"] = []
+    dat["masks_n2v"] = []
+    dat["masks_clean"] = masks
+    dat["noisy"] = noisy
+    dat["clean"] = clean
+    dat["ap_n2v"] = np.zeros((len(noisy), len(noisy[0]), len(thresholds)))
+
+    for n, imgs in enumerate(noisy):
+        imgs_n2v = []
+        for i in trange(len(imgs)):
+            out = noise2void.train_per_image(imgs[i].astype("float32"))
+            imgs_n2v.append(out)
+        imgs_n2v = np.array(imgs_n2v)
+        masks_n2v = seg_model.eval(imgs_n2v, diameter=diameter, channels=[0,0],
+                                   normalize=normalize)[0]
+        ap, tp, fp, fn = metrics.average_precision(masks, masks_n2v, threshold=thresholds)
+        #print(ap[:,0])
+        dat["ap_n2v"][n] = ap
+        dat["imgs_n2v"].append(imgs_n2v)
+        dat["masks_n2v"].append(masks_n2v)
+        print(n, ap.mean(axis=0)[[0, 5, 8]])
+
+    np.save(Path(root) / "ribo_denoise_n2v.npy", dat)
+
+
 def specialist_training(root):
     """ root is path to specialist images (first 89 images of cyto2 and first 11 test images) """
diff --git a/paper/3.0/figures.py b/paper/3.0/figures.py
index 02bef712..a3cd6c5d 100644
--- a/paper/3.0/figures.py
+++ b/paper/3.0/figures.py
@@ -1204,9 +1204,9 @@ def suppfig_specialist(folder, save_fig=True):

     il = 0

-    fig = plt.figure(figsize=(9, 5), dpi=100)
+    fig = plt.figure(figsize=(9, 9), dpi=100)
     yratio = 9 / 5
-    grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
+    grid = plt.GridSpec(3, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1,
                         wspace=0.15, hspace=0.2)

     titles = ["train - clean", "train - noisy", "test - noisy"]
@@ -1265,32 +1265,46 @@ def suppfig_specialist(folder, save_fig=True):
         ax.set_xticks(np.arange(0.5, 1.05, 0.1))
         ax.set_xlim([0.5, 1.0])

-    transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans)
+    grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[1:, :], wspace=0.05,
+                                                        hspace=0.1)

-    kk = [2, 3, 4, 10]
+    transl = mtransforms.ScaledTranslation(-10 / 72, 25 / 72, fig.dpi_scale_trans)
+
+    kk = [2, 3, 4, 6, 10]
     iex = 8
-    ylim = [10, 310]
-    xlim = [100, 500]
+    ylim = [125, 512] # [0, 350]
+    xlim = [50, 325] # [100, 500]
     legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)"
     for j, k in enumerate(kk):
-        ax = plt.subplot(grid[1, j])
-        pos = ax.get_position().bounds
-        ax.set_position([pos[0], pos[1] - 0.07, pos[2], pos[3]])
-        img0 = imgs_all[k][iex].squeeze()
-        img0 *= 1.1
-        img0 = np.clip(img0, 0, 1)
+        outlines_gt = utils.outlines_list(masks_all[0][iex].T.copy(), multiprocessing=False)
+        for ii in range(2):
+            ax = plt.subplot(grid1[ii, j])
+            pos = ax.get_position().bounds
+            ax.set_position([pos[0], pos[1] - 0.07 + ii*0.03, pos[2], pos[3]])
+            img0 = imgs_all[k][iex].squeeze().T
+            masks0 = masks_all[k][iex].squeeze().T
+            img0 *= 1.
+            img0 = np.clip(img0, 0, 1)

-        ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
-        ax.axis("off")
-        ax.set_ylim(ylim)
-        ax.set_xlim(xlim)
-        ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
-        ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
-                transform=ax.transAxes)
-        if j == 0:
-            il = plot_label(ltr, il, ax, transl, fs_title)
-            ax.text(0.02, 1.2, "Denoised test image", fontsize="large",
-                    fontstyle="italic", transform=ax.transAxes)
+            ax.imshow(img0, cmap="gray", vmin=0, vmax=1)
+            if ii==1:
+                outlines = utils.outlines_list(masks0, multiprocessing=False)
+                for o in outlines_gt:
+                    ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2)
+                for o in outlines:
+                    ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--")
+            ax.axis("off")
+            ax.set_ylim(ylim)
+            ax.set_xlim(xlim)
+            if ii==0:
+                ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium")
+            else:
+                ax.text(1, -0.04, f"AP@0.5 = {aps[k,iex,0] : 0.2f}", va="top", ha="right",
+                        transform=ax.transAxes)
+            if j == 0 and ii==0:
+                il = plot_label(ltr, il, ax, transl, fs_title)
+                ax.text(0.02, 1.15, "Denoised test image", fontsize="large",
+                        fontstyle="italic", transform=ax.transAxes)

     print(aps.mean(axis=1)[:, [0, 5, 8]])

@@ -1493,9 +1507,9 @@ def fig6(folder, save_fig=True):

     diams = [utils.diameters(lbl)[0] for lbl in lbls]

-    gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
+    gen_model = "oneclick_cyto3" #"/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039"
     model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
-                                 pretrained_model=gen_model)
+                                 model_type=gen_model)
     seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
     pscales = [1.5, 20., 1.5, 1., 5., 40., 3.]
     denoise.deterministic()
@@ -1561,6 +1575,7 @@ def fig6(folder, save_fig=True):
     legstr0 = ["", u"\u2013 noisy image", u"\u2013 original", u"\u2013 noise-specific",
                "\u2013 data-specific", u"-- one-click"]
     theight = [0, 0,4,3,2,1]
+    cstr = ["noisy\nimage", "blurry\nimage", "bilinear\nupsampled"]
     for i in range(6):
         ctype = "cellpose test set" if i < 3 else "nuclei test set"
         noise_type = ["denoising", "deblurring", "upsampling"][i % 3]
@@ -1580,7 +1595,7 @@ def fig6(folder, save_fig=True):
         if i == 1 or i == 4:
             ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center",
                     fontsize="large")
-
+        ax.text(0.03, 0.03, cstr[i%3], transform=ax.transAxes, fontsize="small")
         ax.set_ylim([0, 0.72])
         ax.set_xticks(np.arange(0.5, 1.05, 0.25))
         ax.set_xlim([0.5, 1.0])
@@ -1593,9 +1608,98 @@ def fig6(folder, save_fig=True):
     ]
     colsj = cols0[[0, 1, -1]]

-    ly0 = 250
+    generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
+                                  titlesj, colsj, titlesi, j0=0, il=il)
+
+    if save_fig:
+        os.makedirs("figs/", exist_ok=True)
+        fig.savefig("figs/fig6.pdf", dpi=150)
+
+def suppfig_generalist_examples(folder, save_fig=True):
+    cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162],
+                      [246, 198, 173], [192, 71, 29], ])
+    cols0 = cols0 / 255
+    titlesi = [
+        "Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast",
+        "Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs"
+    ]
+    colsj = cols0[[0, 1, -1]]
+    folders = [
+        "cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC",
+        "bact_phase", "bact_fluor", "deepbacs"
+    ]
+    diam_mean = 30.
+
+    #iexs = [340, 50, 10, 5, 70, 2, 33]
+    iexs = [305, 1071, 0, 3, 70, 9, 31]
+    imgs, lbls = [[], [], []], []
+    masks = [[], [], []]
+    for f, iex in zip(folders[2:], iexs):
+        dat = np.load(Path(folder) / f"{f}_generalist_masks.npy",
+                      allow_pickle=True).item()
+        img = dat["imgs"][iex].copy()
+        img = img[:1] if img.ndim > 2 else img
+        img = np.maximum(0, transforms.normalize99(img))
+        imgs[0].append(img)
+        masks[0].append(dat["masks_pred"][iex])
+        lbls.append(dat["masks"][iex].astype("uint16"))
+
+    diams = [utils.diameters(lbl)[0] for lbl in lbls]

-    transl = mtransforms.ScaledTranslation(-15 / 72, 30 / 72, fig.dpi_scale_trans)
+    gen_model = "oneclick_cyto3"
+    model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
+                                 model_type=gen_model)
+    seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
+
+    fig = plt.figure(figsize=(14, 8), dpi=100)
+    grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.03)
+
+    for ii in range(2):
+        if ii==0:
+            titlesj = ["clean", "blurry", "deblurred (one-click)"]
+        else:
+            titlesj = ["clean", "downsampled", "upsampled (one-click)"]
+        masks[1] = []
+        masks[2] = []
+        imgs[1] = []
+        imgs[2] = []
+        sigmas = [5., 3., 7., 12., 5., 5., 3.]
+        ds = [6,4,8,8,6,6,6]
+        denoise.deterministic()
+        for i, img in tqdm(enumerate(imgs[0])):
+            img0 = torch.from_numpy(img.copy()).squeeze().unsqueeze(0).unsqueeze(0)
+            img0 = img0.float()
+            noisy0 = denoise.add_noise(img0, poisson=0., downsample=1. if ii==1 else 0,
+                                       blur=1., ds=ds[i] if ii==1 else 0,
+                                       sigma0 = sigmas[i] if ii==0 else sigmas[i]/2,
+                                       sigma1 = sigmas[i] if ii==0 else sigmas[i]/2,
+                                       pscale=120.).numpy().squeeze()
+            denoised0 = model.eval(noisy0, diameter=diams[i], normalize=True)
+
+            imgs[1].append(noisy0)
+            imgs[2].append(denoised0)
+            for j in range(1, 3):
+                masks[j].append(
+                    seg_model.eval(
+                        imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5,
+                        flow_threshold=0.4, augment=True, bsize=224,
+                        niter=2000 if folders[i - 2] == "bact_phase" else None)[0])
+        api = np.array(
+            [metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)])
+
+        generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
+                                      titlesj, colsj, titlesi, j0=-1 + 2*ii, letter=True)
+    if save_fig:
+        os.makedirs("figs/", exist_ok=True)
+        fig.savefig("figs/suppfig_genex.pdf", dpi=150)
+
+def generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api,
+                                  titlesj, colsj, titlesi, j0=0, ly0=250, letter=False, il=0):
+    if letter:
+        il = j0>0
+        transl = mtransforms.ScaledTranslation(-20 / 72, 15 / 72, fig.dpi_scale_trans)
+    else:
+        transl = mtransforms.ScaledTranslation(-20 / 72, 5 / 72, fig.dpi_scale_trans)
     for i in range(len(imgs[0])):
         ratio = diams[i] / 30.
         d = utils.diameters(lbls[i])[0]
@@ -1608,20 +1712,18 @@
         for j in range(1, 3):
             img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1)
             for k in range(2):
-                ax = plt.subplot(grid[j, 2 * i + k])
+                ax = plt.subplot(grid[j+j0, 2 * i + k])
                 pos = ax.get_position().bounds
                 ax.set_position([
-                    pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07,
+                    pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.08*(j0==0),
                     pos[2], pos[3]
                 ])
                 if 1:
                     ax.imshow(img, cmap="gray", vmin=0,
-                              vmax=0.35 if j == 1 and i == 2 else 1.0)
+                              vmax=0.35 if j == 1 and i == 2 and j0==0 else 1.0)
                 if k == 1:
                     outlines = utils.outlines_list(masks[j][i], multiprocessing=False)
-                    #for o in outlines_gt:
-                    #    ax.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=1, ls="-")
                     for o in outlines:
                         ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--")

@@ -1638,17 +1740,19 @@
                 if k == 0 and i == 0:
                     ax.text(-0.22, 0.5, titlesj[j], transform=ax.transAxes, va="center",
                             rotation=90, color=colsj[j], fontsize="medium")
-                if j == 0:
+                if j==1:
                     il = plot_label(ltr, il, ax, transl, fs_title)
-                    ax.text(-0.0, 1.22, "Denoising examples from other datasets",
+                    ax.text(-0.02, 1.05, "Denoising examples from other datasets",
                             fontstyle="italic", transform=ax.transAxes,
                             fontsize="large")
-                if k == 0 and j == 0:
-                    ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
-                            fontsize="medium")
-    if save_fig:
-        os.makedirs("figs/", exist_ok=True)
-        fig.savefig("figs/fig6.pdf", dpi=150)
+                if j==1 and letter:
+                    ax.text(-0.0, 1.11, "Deblurring examples from other datasets" if j0==-1 else "Upsampling examples from other datasets",
+                            fontstyle="italic", transform=ax.transAxes,
+                            fontsize="large")
+                    il = plot_label(ltr, il, ax, transl, fs_title)
+                #if k == 0 and (j == 0 or (j==1 and j0==0)):
+                    #ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes,
+                    #        fontsize="medium")

 def load_seg_generalist(folder):
     folders = [
diff --git a/paper/3.0/train_subsets.py b/paper/3.0/train_subsets.py
new file mode 100644
index 00000000..162c66de
--- /dev/null
+++ b/paper/3.0/train_subsets.py
@@ -0,0 +1,133 @@
+import time, os
+import numpy as np
+from cellpose import io, transforms, utils, models, dynamics, metrics, resnet_torch, denoise
+from cellpose.transforms import normalize_img
+from pathlib import Path
+import torch
+from torch import nn
+import time
+import argparse
+
+def main():
+    parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
+    parser.add_argument("--nsub", default=0, type=int)
+    parser.add_argument("--seed", default=0, type=int)
+    parser.add_argument("--n_epochs", default=2000, type=int)
+
+    args = parser.parse_args()
+    n_epochs = args.n_epochs
+    nsub = args.nsub
+    seed = args.seed
+    print(n_epochs, nsub, seed)
+
+    root = Path("/groups/stringer/stringerlab/datasets_cellpose/images_cyto2/")
+    #pretrained_model = str(root / "cyto3")
+    batch_size = 8
+
+    io.logger_setup()
+
+    device = torch.device("cuda")
+    ntrain = 796
+    np.random.seed(seed)
+    if nsub > 0:
+        iperm = np.random.permutation(ntrain)[:nsub]
+    else:
+        iperm = np.arange(ntrain)
+
+    # keep green channel
+    train_data = []
+    for i in iperm:
+        img = io.imread(root / "train" / f"{i:03d}_img.tif")
+        if img.ndim > 2:
+            img = img[0]
+        train_data.append(np.maximum(transforms.normalize99(img), 0)[np.newaxis,:,:])
+
+    train_labels = [io.imread(root / "train" / f"{i:03d}_img_flows.tif") for i in iperm]
+
+    test_data = []
+    for i in range(68):
+        img = io.imread(root / "test" / f"{i:03d}_img.tif")
+        if img.ndim > 2:
+            img = img[0]
+        test_data.append(np.maximum(transforms.normalize99(img), 0)[np.newaxis,:,:])
+    test_labels = [io.imread(root / "test" / f"{i:03d}_img_flows.tif") for i in range(91)]
+
+    model = denoise.DenoiseModel(gpu=True, nchan=1, pretrained_model=None)
+
+    (root / "models").mkdir(exist_ok=True)
+    # poisson training
+    model_path, train_losses, test_losses = denoise.train(model.net, train_data=train_data, train_labels=train_labels,
+                                                          test_data=test_data, test_labels=test_labels,
+                                                          save_path=root / "models", blur=0., gblur=0.5,
+                                                          iso=True, lam=[1.,1.5,0],
+                                                          downsample=0., poisson=0.8, beta=0.7 , n_epochs=n_epochs,
+                                                          learning_rate=0.001, weight_decay=1e-5,
+                                                          seg_model_type="cyto2", model_name=f"denoise_cyto2_{nsub}_{seed}_{n_epochs}")
+
+
+if __name__ == "__main__":
+    main()
+
+def save_results(folder, sroot):
+    nsubs = 2 ** np.arange(1, 10)
+    nsubs = np.vstack((nsubs, 796))
+    thresholds = np.arange(0.5, 1.05, 0.05)
+    seg_model = models.CellposeModel(gpu=True, model_type="cyto2")
+    noise_type = "poisson"
+    ctype = "cyto2"
+
+    folder_name = ctype
+    diam_mean = 30
+    root = Path(folder) / f"images_{folder_name}/"
+    model_name = "cyto2"
+
+    ### cellpose enhance
+    dat = np.load(root / "noisy_test" / f"test_{noise_type}.npy",
+                  allow_pickle=True).item()
+    test_noisy = dat["test_noisy"][:68]
+    masks_true = dat["masks_true"][:68]
+    diam_test = dat["diam_test"][:68] if "diam_test" in dat else 30. * np.ones(
+        len(test_noisy))
+
+    aps = np.zeros((len(nsubs), 5, len(thresholds)))
+    for k, nsub in enumerate(nsubs):
+        if nsub != 796:
+            continue
+        ni = nsub if nsub < 796 else 0
+        for seed in range(5):
+            si = seed if nsub < 796 else f"{seed}_2000"
+            dn_model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
+                                            pretrained_model=str(sroot / f"denoise_cyto2_{ni}_{si}"))
+            imgs2 = dn_model.eval([test_noisy[i][0] for i in range(len(test_noisy))],
+                                  diameter=diam_test, channel_axis=0)
+
+            masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0],
+                                                     diameter=diam_test, channel_axis=-1,
+                                                     normalize=True)
+
+            ap, tp, fp, fn = metrics.average_precision(masks_true, masks2, threshold=thresholds)
+            aps[k, seed] = ap.mean(axis=0)
+            print(f"{nsub} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}")
+
+    np.save("nsubs_aps.npy", aps)
+
+    n_epochss = np.array([100, 200, 400, 800, 1600, 2000, 3200])
+    nsub, seed = 0, 0
+    thresholds = np.arange(0.5, 1.05, 0.05)
+    aps = np.zeros((len(n_epochss), 5, len(thresholds)))
+    for k, n_epochs in enumerate(n_epochss):
+        for seed in range(5):
+            dn_model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean,
+                                            pretrained_model=str(sroot / f"denoise_cyto2_{nsub}_{seed}_{n_epochs}"))
+            imgs2 = dn_model.eval([test_noisy[i][0] for i in range(len(test_noisy))],
+                                  diameter=diam_test, channel_axis=0)
+
+            masks2, flows2, styles2 = seg_model.eval(imgs2, channels=[1, 0],
+                                                     diameter=diam_test, channel_axis=-1,
+                                                     normalize=True)
+
+            ap, tp, fp, fn = metrics.average_precision(masks_true, masks2, threshold=thresholds)
+            aps[k, seed] = ap.mean(axis=0)
+            print(f"{n_epochs} AP@0.5 \t = {ap[:,0].mean(axis=0):.3f}")
+
+    np.save("n_epochs_aps.npy", aps)
\ No newline at end of file
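
Usage sketch (not part of the patch): the new "oneclick" entries this commit registers in MODEL_NAMES load through denoise.DenoiseModel via model_type, the same way fig6 and suppfig_generalist_examples use them above. The image path and diameter below are placeholder assumptions; everything else mirrors calls that appear in the patch.

# run the subset-training script added by this patch, e.g.:
#   python paper/3.0/train_subsets.py --nsub 64 --seed 0 --n_epochs 2000
from cellpose import denoise, models, io

img = io.imread("img.tif")  # hypothetical 2-D single-channel image

# one-click restoration model name registered by the denoise.py change above
dn_model = denoise.DenoiseModel(gpu=True, model_type="oneclick_cyto3")
restored = dn_model.eval(img, diameter=30., channels=[0, 0])

# segment the restored image with the cyto3 generalist, as in the figure code
seg_model = models.CellposeModel(gpu=True, model_type="cyto3")
masks = seg_model.eval(restored, diameter=30., channels=[0, 0])[0]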