From fd8c09862f51ff1d4f02531ddec4b38482320d8b Mon Sep 17 00:00:00 2001 From: Tony Tung Date: Tue, 8 Oct 2019 15:30:58 -0700 Subject: [PATCH] Fix bugs in per-round-max-decoder 1. When the entire row is NaN, the decoder chokes. This is remedied by decoding on an array where the NaN values are replaced with 0s. 2. When the entire row is of equal intensity, `np.argmax` resolves the tie by returning the first occurrence of the maximum, so the first column is always picked as the winner. That erroneously decodes as ch=0 having the max intensity. This code detects that scenario, and rewrites the ch to an impossible value in that situation. Test plan: Wrote tests that failed with the existing code, applied fixes and verified that they now work. Fixes #1485 --- notebooks/ISS.ipynb | 2 +- starfish/core/codebook/codebook.py | 11 +++++- .../test/test_per_round_max_decode.py | 39 +++++++++++++++++++ .../test/full_pipelines/api/test_iss_api.py | 2 +- 4 files changed, 51 insertions(+), 3 deletions(-) diff --git a/notebooks/ISS.ipynb b/notebooks/ISS.ipynb index d6bf6dbc8..ad5fe1b54 100644 --- a/notebooks/ISS.ipynb +++ b/notebooks/ISS.ipynb @@ -450,4 +450,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/starfish/core/codebook/codebook.py b/starfish/core/codebook/codebook.py index 62a912bb5..c531625db 100644 --- a/starfish/core/codebook/codebook.py +++ b/starfish/core/codebook/codebook.py @@ -651,7 +651,16 @@ def _view_row_as_element(array: np.ndarray) -> np.ndarray: distances=(Features.AXIS, np.empty(0, dtype=np.float64)), passes_threshold=(Features.AXIS, np.empty(0, dtype=bool))) - max_channels = intensities.argmax(Axes.CH.value) + intensities_without_nans = intensities.fillna(0) + max_channels = intensities_without_nans.argmax(Axes.CH.value) + # this snippet of code finds all the (feature, round) spots that have uniform illumination, + # and assigns them to a ch number that's one larger than max possible to ensure that such + # spots decode to `NaN`. 
+ max_channels_max = intensities_without_nans.reduce(np.amax, Axes.CH.value) + max_channels_min = intensities_without_nans.reduce(np.amin, Axes.CH.value) + uniform_illumination_mask = (max_channels_max == max_channels_min).values + + max_channels.values[uniform_illumination_mask] = intensities.sizes[Axes.CH.value] codes = self.argmax(Axes.CH.value) # TODO ambrosejcarr, dganguli: explore this quality score further diff --git a/starfish/core/codebook/test/test_per_round_max_decode.py b/starfish/core/codebook/test/test_per_round_max_decode.py index d7cd5a07f..29077cec3 100644 --- a/starfish/core/codebook/test/test_per_round_max_decode.py +++ b/starfish/core/codebook/test/test_per_round_max_decode.py @@ -155,3 +155,42 @@ def test_argmax_selects_the_last_equal_intensity_channel_and_decodes_consistentl decoded_intensities = codebook.decode_per_round_max(intensities) assert np.array_equal(decoded_intensities[Features.TARGET].values, ['nan', 'GENE_A']) + + +def test_argmax_does_not_select_first_code(): + """ + When all the channels in a round are uniform, argmax erroneously picks the first channel as the + max. In this case, it incorrectly assigns the wrong code for that round. This test ensures + that the workaround we put in for this works correctly. + """ + + data = np.array( + [[[0.0, 1.0], + [1.0, 1.0]], # this round is uniform, so it will erroneously be decoded as the first ch. + [[0.0, 1.0], + [1.0, 0.0]]] + ) + intensities = intensity_table_factory(data) + codebook = codebook_factory() + + decoded_intensities = codebook.decode_per_round_max(intensities) + assert np.array_equal(decoded_intensities[Features.TARGET].values, ['nan', 'GENE_A']) + + +def test_feature_round_all_nan(): + """ + When all the channels in a round are NaN, argmax chokes. This test ensures that the workaround + we put in for this works correctly. 
+ """ + + data = np.array( + [[[0.0, 1.0], + [np.nan, np.nan]], + [[0.0, 1.0], + [1.0, 0.0]]] + ) + intensities = intensity_table_factory(data) + codebook = codebook_factory() + + decoded_intensities = codebook.decode_per_round_max(intensities) + assert np.array_equal(decoded_intensities[Features.TARGET].values, ['nan', 'GENE_A']) diff --git a/starfish/test/full_pipelines/api/test_iss_api.py b/starfish/test/full_pipelines/api/test_iss_api.py index b197ce896..7fb2ec7f8 100644 --- a/starfish/test/full_pipelines/api/test_iss_api.py +++ b/starfish/test/full_pipelines/api/test_iss_api.py @@ -105,7 +105,7 @@ def test_iss_pipeline_cropped_data(tmpdir): assert np.array_equal(genes, np.array(['ACTB', 'CD68', 'CTSL2', 'EPCAM', 'ETV4', 'GAPDH', 'GUS', 'HER2', 'RAC1', 'TFRC', 'TP53', 'VEGF'])) - assert np.array_equal(gene_counts, [20, 1, 5, 2, 1, 11, 1, 3, 2, 1, 1, 2]) + assert np.array_equal(gene_counts, [19, 1, 5, 2, 1, 11, 1, 3, 2, 1, 1, 2]) masks = iss.masks