diff --git a/src/napari_serialcellpose/_tests/test_widget.py b/src/napari_serialcellpose/_tests/test_widget.py index fe9736c..620eb80 100644 --- a/src/napari_serialcellpose/_tests/test_widget.py +++ b/src/napari_serialcellpose/_tests/test_widget.py @@ -1,15 +1,18 @@ from napari_serialcellpose import SerialWidget import numpy as np import pandas as pd - +import pytest from pathlib import Path +import time +import os +import tempfile import shutil def test_load_single_image(make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) - + mypath = Path('src/napari_serialcellpose/_tests/data/single_file_singlechannel/') widget.file_list.update_from_path(mypath) @@ -17,11 +20,11 @@ def test_load_single_image(make_napari_viewer): widget.file_list.setCurrentRow(0) assert len(viewer.layers) == 1 -def test_analyse_single_image_no_save(make_napari_viewer): +def test_analyse_single_image_no_save(qtbot, make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) - + mypath = Path('src/napari_serialcellpose/_tests/data/single_file_singlechannel/') widget.file_list.update_from_path(mypath) @@ -38,24 +41,24 @@ def test_analyse_single_image_no_save(make_napari_viewer): # set diameter and run segmentation widget.spinbox_diameter.setValue(70) widget._on_click_run_on_current() - + # check that segmentatio has been added, named 'mask' and results in 33 objects - assert len(viewer.layers) == 2 + def check_layers(): + assert len(viewer.layers) == 2 + + qtbot.waitUntil(check_layers, timeout=30000) assert viewer.layers[1].name == 'mask' assert viewer.layers[1].data.max() == 33 -def test_analyse_single_image_save(make_napari_viewer): +def test_analyse_single_image_save(qtbot, make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) mypath = Path('src/napari_serialcellpose/_tests/data/single_file_multichannel') - output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_single') - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(exist_ok=True) - + output_dir = Path(tempfile.mkdtemp()) + widget.file_list.update_from_path(mypath) widget.output_folder = output_dir widget.file_list.setCurrentRow(0) @@ -76,23 +79,26 @@ def test_analyse_single_image_save(make_napari_viewer): widget.check_props['size'].setChecked(True) widget.check_props['intensity'].setChecked(True) widget.qcbox_channel_analysis.setCurrentRow(1) - widget._on_click_run_on_current() - assert len(list(output_dir.glob('*mask.tif'))) == 1 + def check_outputs(): + assert len(list(output_dir.glob('*mask.tif'))) == 1 + + qtbot.waitUntil(check_outputs, timeout=30000) + assert len(list(output_dir.joinpath('tables').glob('*_props.csv'))) == 1 + shutil.rmtree(output_dir) -def test_analyse_multi_image(make_napari_viewer): +def test_analyse_multi_image(qtbot, make_napari_viewer): """Test analysis of multiple images in a folder. No properties are analyzed.""" viewer = make_napari_viewer() widget = SerialWidget(viewer) mypath = Path('src/napari_serialcellpose/_tests/data/multifile/') - output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple') - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(exist_ok=True) + + output_dir = Path(tempfile.mkdtemp()) + widget.file_list.update_from_path(mypath) widget.output_folder = output_dir @@ -101,23 +107,22 @@ def test_analyse_multi_image(make_napari_viewer): widget.qcbox_model_choice.setCurrentIndex( [widget.qcbox_model_choice.itemText(i) for i in range(widget.qcbox_model_choice.count())].index('cyto2')) widget.spinbox_diameter.setValue(70) - widget._on_click_run_on_current() + widget._on_click_run_on_folder() - assert len(list(output_dir.glob('*mask.tif'))) == 1 + def check_output(): + assert len(list(output_dir.glob('*mask.tif'))) == 4 - widget._on_click_run_on_folder() - assert len(list(output_dir.glob('*mask.tif'))) == 4 + qtbot.waitUntil(check_output, timeout=30000) + shutil.rmtree(output_dir) -def test_analyse_multi_image_props(make_napari_viewer): +def test_analyse_multi_image_props(qtbot, make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) mypath = Path('src/napari_serialcellpose/_tests/data/multifile/') - output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple3') - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(exist_ok=True) + output_dir = Path(tempfile.mkdtemp()) + widget.file_list.update_from_path(mypath) widget.output_folder = output_dir @@ -133,27 +138,27 @@ def test_analyse_multi_image_props(make_napari_viewer): widget.qcbox_channel_analysis.setCurrentRow(1) widget._on_click_run_on_folder() - assert len(list(output_dir.glob('*mask.tif'))) == 4 - # check that the properties are correct + def check_outputs(): + assert len(list(output_dir.glob('*mask.tif'))) == 4 + + qtbot.waitUntil(check_outputs, timeout=30000) + # check that the properties are correct df = pd.read_csv(output_dir.joinpath( - 'tables', - Path(widget.file_list.currentItem().text()).stem + '_props.csv' - ) + 'tables', + Path(widget.file_list.currentItem().text()).stem + '_props.csv') ) - # check number of columns in df - assert df.shape[1] == 8 + # check number of columns in df + assert df.shape[1] == 8 + -def test_analyse_multichannels(make_napari_viewer): +def test_analyse_multichannels(qtbot, make_napari_viewer): """Test that multiple channels can be used for intensity measurements""" viewer = make_napari_viewer() widget = SerialWidget(viewer) mypath = Path('src/napari_serialcellpose/_tests/data/single_file_multichannel/') - output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_single_multichannelprops') - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(exist_ok=True) + output_dir = Path(tempfile.mkdtemp()) widget.file_list.update_from_path(mypath) widget.output_folder = output_dir @@ -171,6 +176,11 @@ def test_analyse_multichannels(make_napari_viewer): widget._on_click_run_on_folder() + def check_outputs(): + assert len(list(output_dir.glob('*mask.tif'))) == 1 + + qtbot.waitUntil(check_outputs, timeout=30000) + # check that the properties are correct df = pd.read_csv(output_dir.joinpath( 'tables', @@ -180,16 +190,14 @@ def test_analyse_multichannels(make_napari_viewer): # check number of columns in df assert df.shape[1] == 11 -def test_mask_loading(make_napari_viewer): +def test_mask_loading(qtbot, make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) mypath = Path('src/napari_serialcellpose/_tests/data/multifile/') - output_dir = Path('src/napari_serialcellpose/_tests/data/analyzed_multiple2') - if output_dir.exists(): - shutil.rmtree(output_dir) - output_dir.mkdir(exist_ok=True) + output_dir = Path(tempfile.mkdtemp()) + widget.file_list.update_from_path(mypath) widget.output_folder = output_dir @@ -200,6 +208,12 @@ def test_mask_loading(make_napari_viewer): widget.spinbox_diameter.setValue(70) widget._on_click_run_on_current() + # check that segmentation has been added + def check_layers(): + assert len(viewer.layers) == 3 + + qtbot.waitUntil(check_layers, timeout=30000) + # check that when selecting the second file, we get only 2 channels and no mask widget.file_list.setCurrentRow(1) assert len(viewer.layers) == 2 @@ -208,7 +222,7 @@ def test_mask_loading(make_napari_viewer): widget.file_list.setCurrentRow(0) assert len(viewer.layers) == 3 -def test_analyse_single_image_options_yml(make_napari_viewer): +def test_analyse_single_image_options_yml(qtbot, make_napari_viewer): viewer = make_napari_viewer() widget = SerialWidget(viewer) @@ -232,5 +246,10 @@ def test_analyse_single_image_options_yml(make_napari_viewer): widget._on_click_run_on_current() - # check that because of small diameter from yml file, we get only 5 elements + # check that segmentation has been added + def check_layers(): + assert len(viewer.layers) == 3 + + qtbot.waitUntil(check_layers, timeout=30000) + # check that because of small diameter from yml file, we get only 7 elements assert viewer.layers[2].data.max() == 7 \ No newline at end of file diff --git a/src/napari_serialcellpose/serial_widget.py b/src/napari_serialcellpose/serial_widget.py index 3c04d0a..053b56a 100644 --- a/src/napari_serialcellpose/serial_widget.py +++ b/src/napari_serialcellpose/serial_widget.py @@ -4,6 +4,9 @@ from qtpy.QtCore import Qt import magicgui.widgets from napari.layers import Image +from napari.qt import create_worker, thread_worker +from napari.utils.notifications import show_info + from .folder_list_widget import FolderList from .serial_analysis import run_cellpose, load_props, load_allprops @@ -313,31 +316,37 @@ def _on_click_run_on_current(self): channel_analysis_names = [x.text() for x in self.qcbox_channel_analysis.selectedItems()] reg_props = [k for k in self.check_props.keys() if self.check_props[k].isChecked()] - # run cellpose - segmented, props = run_cellpose( - image_path=image_path, - cellpose_model=self.cellpose_model, - output_path=self.output_folder, - diameter=diameter, - flow_threshold=self.flow_threshold.value(), - cellprob_threshold=self.cellprob_threshold.value(), - clear_border=self.check_clear_border.isChecked(), - channel_to_segment=channel_to_segment, - channel_helper=channel_helper, - channel_measure=channel_analysis, - channel_measure_names=channel_analysis_names, - properties=reg_props, - options_file=self.options_file_path, - force_no_rgb=self.check_no_rgb.isChecked(), - ) - self.viewer.layers.events.inserted.disconnect(self._on_change_layers) - self.viewer.add_labels(segmented, name='mask') - if len(reg_props) > 0: - self.add_table_props(props) + # run cellpose + seg_worker = create_worker(run_cellpose, + image_path=image_path, + cellpose_model=self.cellpose_model, + output_path=self.output_folder, + diameter=diameter, + flow_threshold=self.flow_threshold.value(), + cellprob_threshold=self.cellprob_threshold.value(), + clear_border=self.check_clear_border.isChecked(), + channel_to_segment=channel_to_segment, + channel_helper=channel_helper, + channel_measure=channel_analysis, + channel_measure_names=channel_analysis_names, + properties=reg_props, + options_file=self.options_file_path, + force_no_rgb=self.check_no_rgb.isChecked(), + _progress=True + ) + + def get_seg_worker(output): + self.viewer.add_labels(output[0], name='mask') + if len(reg_props) > 0: + self.add_table_props(output[1]) + self.viewer.layers.events.inserted.connect(self._on_change_layers) + + show_info('Running Segmentation...') + seg_worker.start() + seg_worker.returned.connect(get_seg_worker) - self.viewer.layers.events.inserted.connect(self._on_change_layers) def _on_click_run_on_folder(self): """Run cellpose on all images in folder""" @@ -359,25 +368,31 @@ def _on_click_run_on_folder(self): channel_analysis_names = [x.text() for x in self.qcbox_channel_analysis.selectedItems()] reg_props = [k for k in self.check_props.keys() if self.check_props[k].isChecked()] - for batch in file_list_partition: - _, _ = run_cellpose( - image_path=batch, - cellpose_model=self.cellpose_model, - output_path=self.output_folder, - diameter=diameter, - flow_threshold=self.flow_threshold.value(), - cellprob_threshold=self.cellprob_threshold.value(), - clear_border=self.check_clear_border.isChecked(), - channel_to_segment=channel_to_segment, - channel_helper=channel_helper, - channel_measure=channel_analysis, - channel_measure_names=channel_analysis_names, - properties=reg_props, - options_file=self.options_file_path, - force_no_rgb=self.check_no_rgb.isChecked(), - ) - - self._on_click_load_summary() + @thread_worker(progress={'total': len(file_list_partition), 'desc': 'Running batch segmentation'}) + def run_batch(file_list_partition): + for batch in file_list_partition: + yield run_cellpose( + image_path=batch, + cellpose_model=self.cellpose_model, + output_path=self.output_folder, + diameter=diameter, + flow_threshold=self.flow_threshold.value(), + cellprob_threshold=self.cellprob_threshold.value(), + clear_border=self.check_clear_border.isChecked(), + channel_to_segment=channel_to_segment, + channel_helper=channel_helper, + channel_measure=channel_analysis, + channel_measure_names=channel_analysis_names, + properties=reg_props, + options_file=self.options_file_path, + force_no_rgb=self.check_no_rgb.isChecked(), + ) + + show_info('Running Segmentation...') + batch_worker = run_batch(file_list_partition) + batch_worker.start() + + batch_worker.returned.connect(self._on_click_load_summary) def get_channels_to_use(self): """Translate selected channels in QCombox into indices. @@ -609,4 +624,4 @@ def __init__(self, parent=None, col=1, row=1, width=6, height=4, dpi=100): for j in range(col): self.ax[i,j] = fig.add_subplot(row, col, count) count+=1 - super(MplCanvas, self).__init__(fig) \ No newline at end of file + super(MplCanvas, self).__init__(fig)