From 0c58332c6cfd8a22e0de79ca46fe61187e955b2a Mon Sep 17 00:00:00 2001
From: Matt Einhorn
Date: Mon, 13 May 2024 22:01:59 -0400
Subject: [PATCH 01/13] Add cell matching methods.

---
 brainglobe_utils/cells/cells.py        | 348 ++++++++++++++++++++++++-
 pyproject.toml                         |   1 +
 tests/tests/test_cells/test_matches.py |  73 ++++++
 3 files changed, 421 insertions(+), 1 deletion(-)
 create mode 100644 tests/tests/test_cells/test_matches.py

diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py
index 1b7b8cc..9de1778 100644
--- a/brainglobe_utils/cells/cells.py
+++ b/brainglobe_utils/cells/cells.py
@@ -6,12 +6,25 @@
 import math
 import os
 import re
+import threading
 from collections import defaultdict
 from functools import total_ordering
-from typing import Any, DefaultDict, Dict, List, Tuple, Union
+from typing import (
+    Any,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 from xml.etree import ElementTree
 from xml.etree.ElementTree import Element as EtElement

+import numpy as np
+from numba import njit, objmode
+from tqdm import tqdm
+

 @total_ordering
 class Cell:
@@ -381,3 +394,336 @@ class MissingCellsError(Exception):
     """Custom exception class for when no cells are found in a file"""

     pass
+
+
+def to_numpy_pos(
+    cells: List[Cell], cell_type: Optional[int] = None
+) -> np.ndarray:
+    """
+    Takes a list of Cell objects, selects only the cells of type `cell_type`
+    (if not None), and returns a single 2D array of shape Nx3 with the
+    positions of the cells.
+    """
+    # for a large cell list, pre-compute the size
+    n = len(cells)
+    if cell_type is not None:
+        n = sum([cell.type == cell_type for cell in cells])
+    np_cells = np.empty((n, 3), dtype=np.float_)
+
+    i = 0
+    for cell in cells:
+        if cell_type is not None and cell.type != cell_type:
+            continue
+        np_cells[i, :] = cell.x, cell.y, cell.z
+        i += 1
+
+    return np_cells
+
+
+def from_numpy_pos(pos: np.ndarray, cell_type: int) -> List[Cell]:
+    """
+    Takes a 2D numpy position array of shape Nx3 and returns a list of Cell
+    objects of the given cell_type at those positions.
+    """
+    cells = []
+    for i in range(pos.shape[0]):
+        cell = Cell(pos=pos[i, :].tolist(), cell_type=cell_type)
+        cells.append(cell)
+
+    return cells
+
+
+def match_cells(
+    cells: List[Cell], other: List[Cell], threshold: float = np.inf
+) -> Tuple[List[int], List[Tuple[int, int]], List[int]]:
+    """
+    Given two lists of cells, finds a pairing of cells from `cells` and
+    `other` such that the total Euclidean distance between the assigned
+    matches across all `cells` is minimized.
+
+    Remaining cells (e.g. if one list is longer, or if some matches violate
+    the threshold) are indicated as well.
+
+    E.g.::
+
+        >>> cells = [
+        >>>     Cell([20, 20, 20], Cell.UNKNOWN),
+        >>>     Cell([10, 10, 10], Cell.UNKNOWN),
+        >>>     Cell([40, 40, 40], Cell.UNKNOWN),
+        >>>     Cell([50, 50, 50], Cell.UNKNOWN),
+        >>> ]
+        >>> other = [
+        >>>     Cell([5, 5, 5], Cell.UNKNOWN),
+        >>>     Cell([15, 15, 15], Cell.UNKNOWN),
+        >>>     Cell([35, 35, 35], Cell.UNKNOWN),
+        >>>     Cell([100, 100, 100], Cell.UNKNOWN),
+        >>>     Cell([200, 200, 200], Cell.UNKNOWN),
+        >>> ]
+        >>> match_cells(cells, other, threshold=20)
+        ([3], [[0, 1], [1, 0], [2, 2]], [3, 4])
+
+    Parameters
+    ----------
+    cells : list of Cells.
+    other : Another list of Cells.
+    threshold : float, optional. Defaults to np.inf.
+        The threshold to use to remove bad matches. Any match pair whose
+        distance is greater than the threshold will be excluded from the
+        matching.
+
+    Returns
+    -------
+    tuple :
+        missing_cells: List of all the indices of `cells` that found no match
+            in `other` (sorted).
+        good_matches: List of tuples with all the (cells, other) index pairs
+            that matched below the threshold. It's sorted by the `cells`
+            column.
+        missing_other: List of all the indices of `other` that found no match
+            in `cells` (sorted).
+    """
+    if __progress_update.updater is not None:
+        # I can't think of an instance where this will happen, but better
+        # safe than sorry
+        raise TypeError(
+            "An instance of match_cells is already running in this "
+            "thread. Try running again once it completes"
+        )
+    c1 = to_numpy_pos(cells)
+    c2 = to_numpy_pos(other)
+
+    # c1 must be shorter than or equal in length to c2
+    flip = len(cells) > len(other)
+    if flip:
+        c1, c2 = c2, c1
+
+    progress = tqdm(desc="Matching cells", total=len(c1), unit="cells")
+    __progress_update.updater = progress.update
+    # for each index corresponding to c1, returns the index in c2 that matches
+    assignment = match_points(c1, c2)
+    progress.close()
+    __progress_update.updater = None
+
+    missing_c1, good_matches, missing_c2 = analyze_point_matches(
+        c1, c2, assignment, threshold
+    )
+    if flip:
+        missing_c1, missing_c2 = missing_c2, missing_c1
+        good_matches = np.flip(good_matches, axis=1)
+        good_matches = good_matches[good_matches[:, 0].argsort()]
+
+    return missing_c1.tolist(), good_matches.tolist(), missing_c2.tolist()
+
+
+# A terrible hack, but you can't pass arbitrary objects to an njit function.
+# It can, however, access global variables and call them in objmode. So we
+# pass the progress updater to match_points via this global variable and
+# function. It is nominally thread safe, but it's not safe to modify it from
+# within a thread while match_points is running.

+__progress_update = threading.local()
+__progress_update.updater = None
+
+
+def __compare_progress():
+    if __progress_update.updater is not None:
+        __progress_update.updater()
+
+
+@njit
+def match_points(pos1: np.ndarray, pos2: np.ndarray) -> np.ndarray:
+    """
+    Given two arrays, each a list of positions, finds for each point in
+    `pos1` a point in `pos2` such that the total distance between the
+    assigned matches across all `pos1` is minimized.
+
+    E.g.::
+
+        >>> pos1 = np.array([[20, 10, 30, 40]]).T
+        >>> pos2 = np.array([[5, 15, 25, 35, 50]]).T
+        >>> matches = match_points(pos1, pos2)
+        >>> matches
+        array([1, 0, 2, 3])
+
+    Parameters
+    ----------
+    pos1 : np.ndarray
+        2D array of NxK, where N is the number of positions and K is the
+        number of dimensions (e.g. 3 for x, y, z).
+    pos2 : np.ndarray
+        2D array of MxK, where M is the number of positions and K is the
+        number of dimensions (e.g. 3 for x, y, z).
+
+        The relationship N <= M must be true.
+
+    Returns
+    -------
+    matches : np.ndarray
+        1D array of length N. Each index i in matches corresponds
+        to index i in `pos1`. The value of index i in matches is the index
+        j in pos2 that is the best match for that pos1.
+
+        I.e. the match is (pos1[i], pos2[matches[i]]).
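+
+    Note that the matching minimizes the *total* distance, which can differ
+    from greedily pairing each point with its nearest remaining neighbour.
+    A minimal 1D sketch (the same points reappear in the threshold tests
+    added later in this series): greedy matching pairs 12 with 10 (distance
+    2), forcing 0 onto 22, for a total cost of 24; the optimal assignment
+    costs 20::
+
+        >>> pos1 = np.array([[0.0, 12.0]]).T
+        >>> pos2 = np.array([[10.0, 22.0]]).T
+        >>> match_points(pos1, pos2)
+        array([0, 1])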
+ """ + # based on https://en.wikipedia.org/wiki/Hungarian_algorithm + n_rows = pos1.shape[0] + n_cols = pos2.shape[0] + if n_rows > n_cols: + raise ValueError( + "The length of pos1 must be less than or equal to length of pos2" + ) + + potentials_rows = np.zeros(n_rows) + potentials_cols = np.zeros(n_cols + 1) + assignment_row = np.full(n_cols + 1, -1, dtype=np.int_) + min_to = np.empty(n_cols + 1, dtype=np.float_) + # previous worker on alternating path + prev_col_for_col = np.empty(n_cols + 1, dtype=np.int_) + # whether col is in use + col_used = np.zeros(n_cols + 1, dtype=np.bool_) + + # assign row-th match + for row in range(n_rows): + col = n_cols + assignment_row[col] = row + # min reduced cost over edges from Z to worker w + min_to[:] = np.inf + prev_col_for_col[:] = -1 + + # runs at most row + 1 times + while assignment_row[col] != -1: + col_used[col] = True + row_cur = assignment_row[col] + delta = np.inf + col_next = -1 + + for col_i in range(n_cols): + if not col_used[col_i]: + dist = np.sum(np.square(pos1[row_cur, :] - pos2[col_i, :])) + if dist == np.inf: + raise ValueError( + "The distance between point is too large" + ) + + cur = ( + dist + - potentials_rows[row_cur] + - potentials_cols[col_i] + ) + if cur < min_to[col_i]: + min_to[col_i] = cur + prev_col_for_col[col_i] = col + + if min_to[col_i] < delta: + delta = min_to[col_i] + col_next = col_i + + # delta will always be non-negative, + # except possibly during the first time this loop runs + # if any entries of C[row] are negative + for col_i in range(n_cols + 1): + if col_used[col_i]: + potentials_rows[assignment_row[col_i]] += delta + potentials_cols[col_i] -= delta + else: + min_to[col_i] -= delta + col = col_next + + # update assignments along alternating path + while col != n_cols: + col_i = prev_col_for_col[col] + assignment_row[col] = assignment_row[col_i] + col = col_i + + with objmode(): + __compare_progress() + + # compute match from assignment + matches = np.empty(n_rows, dtype=np.int_) + for i in range(n_cols): + if assignment_row[i] != -1: + matches[assignment_row[i]] = i + + return matches + + +@njit +def analyze_point_matches( + pos1: np.ndarray, + pos2: np.ndarray, + matches: np.ndarray, + threshold: float = np.inf, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Given a matching found by `match_points`, it optionally applies a threshold + and splits the matched points from unmatched points in a friendlier way. + + E.g.:: + + >>> pos1 = np.array([[20, 10, 30, 40, 50]]).T + >>> pos2 = np.array([[5, 15, 25, 35, 100, 200]]).T + >>> matches = match_points(pos1, pos2) + >>> matches + array([1, 0, 2, 3, 4]) + >>> analyze_point_matches(pos1, pos2, matches) + (array([], dtype=int64), + array([[0, 1], + [1, 0], + [2, 2], + [3, 3], + [4, 4]], dtype=int64), + array([5], dtype=int64)) + >>> analyze_point_matches(pos1, pos2, matches, threshold=10) + (array([4], dtype=int64), + array([[0, 1], + [1, 0], + [2, 2], + [3, 3]], dtype=int64), + array([4, 5], dtype=int64)) + + Parameters + ---------- + pos1 : np.ndarray + Same as `match_points`. + pos2 : np.ndarray + Same as `match_points`. + matches : np.ndarray + The matches returned by `match_points`. + threshold : float, optional. Defaults to np.inf. + The threshold to use to remove bad matches. Any match pair whose + distance is greater than the threshold will be removed from the + matching and added to the missing_pos1 and missing_pos2 arrays. 
+ + Returns + ------- + tuple : (np.ndarray, np.ndarray, np.ndarray) + missing_pos1: 1d array of all the indices of pos1 that found no match + in pos2 (sorted). + good_matches: 2d array with all the (pos1, pos2) indices that remained + in the matching. It's of size Rx2. It's sorted by the first column. + missing_pos2: 1d array of all the indices of pos2 that found no match + in pos1 (sorted). + """ + # indices and mask on indices + pos2_n = len(pos2) + pos2_i = np.arange(pos2_n) + pos2_mask = np.ones(pos2_n, dtype=np.bool_) + # those in pos2 who have a match in pos1 + pos2_mask[matches] = False + # all the pos2 that have no matches from pos1 + missing_pos2 = pos2_i[pos2_mask] + + # repackage matches so the first column is the pos1 idx and 2nd column is + # the corresponding pos2 index + matches_indices = np.stack((np.arange(len(pos1)), matches), axis=1) + + dist = np.sqrt(np.sum(np.square(pos1 - pos2[matches]), axis=1)) + too_large = dist >= threshold + bad_matches = matches_indices[too_large, :] + good_matches = matches_indices[np.logical_not(too_large), :] + + missing_pos1 = bad_matches[:, 0] + # more missing for pos2 for those above threshold + missing_pos2 = np.concatenate((missing_pos2, bad_matches[:, 1])) + missing_pos2 = np.sort(missing_pos2) + + return missing_pos1, good_matches, missing_pos2 diff --git a/pyproject.toml b/pyproject.toml index bc46b48..ed31c64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "configobj", "natsort", "nibabel >= 2.1.0", + "numba", "numpy", "pandas", "psutil", diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py new file mode 100644 index 0000000..e00ac6a --- /dev/null +++ b/tests/tests/test_cells/test_matches.py @@ -0,0 +1,73 @@ +from typing import List + +import numpy as np + +from brainglobe_utils.cells.cells import Cell, from_numpy_pos, match_cells + + +def as_cell(x: List[float]): + d = np.tile(np.asarray([x]).T, (1, 3)) + cells = from_numpy_pos(d, Cell.UNKNOWN) + return cells + + +def test_cell_matches_equal_size(): + a = as_cell([10, 20, 30, 40]) + b = as_cell([5, 15, 25, 35]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert not b_ + assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab + + a = as_cell([20, 10, 30, 40]) + b = as_cell([5, 15, 25, 35]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert not b_ + assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab + + a = as_cell([20, 10, 30, 40]) + b = as_cell([11, 22, 39, 42]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert not b_ + assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab + + +def test_cell_matches_larger_other(): + a = as_cell([1, 12, 100, 80]) + b = as_cell([5, 15, 25, 35, 100]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert b_ == [2] + assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab + + a = as_cell([20, 10, 30, 40]) + b = as_cell([11, 22, 39, 42, 41]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert b_ == [3] + assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab + + +def test_cell_matches_larger_cells(): + a = as_cell([5, 15, 25, 35, 100]) + b = as_cell([1, 12, 100, 80]) + a_, ab, b_ = match_cells(a, b) + assert a_ == [2] + assert not b_ + assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab + + +def test_cell_matches_threshold(): + a = as_cell([10, 12, 100, 80]) + b = as_cell([0, 5, 15, 25, 35, 100]) + a_, ab, b_ = match_cells(a, b) + assert not a_ + assert b_ == [0, 3] + assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab + + a_, ab, b_ = match_cells(a, b, threshold=10) + assert a_ == [3] + assert b_ 
== [0, 3, 4] + assert [[0, 1], [1, 2], [2, 5]] == ab From ccd70d912d1e9f34a1b7d7f26c38c53e1827f26c Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 20 May 2024 20:59:57 -0400 Subject: [PATCH 02/13] Move threshold to matching step so we don't include points too far. --- brainglobe_utils/cells/cells.py | 63 +++++++++++++++--- pyproject.toml | 2 +- tests/tests/test_cells/test_cells.py | 17 +++++ tests/tests/test_cells/test_matches.py | 88 +++++++++++++++++++++++++- 4 files changed, 158 insertions(+), 12 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 9de1778..75cdc28 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -499,9 +499,11 @@ def match_cells( progress = tqdm(desc="Matching cells", total=len(c1), unit="cells") __progress_update.updater = progress.update # for each index corresponding to c1, returns the index in c2 that matches - assignment = match_points(c1, c2) - progress.close() - __progress_update.updater = None + try: + assignment = match_points(c1, c2, threshold) + progress.close() + finally: + __progress_update.updater = None missing_c1, good_matches, missing_c2 = analyze_point_matches( c1, c2, assignment, threshold @@ -530,7 +532,9 @@ def __compare_progress(): @njit -def match_points(pos1: np.ndarray, pos2: np.ndarray) -> np.ndarray: +def match_points( + pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf +) -> np.ndarray: """ Given two arrays, each a list of position. For each point in `pos1` it finds a point in `pos2` such that the distance between the assigned @@ -554,6 +558,16 @@ def match_points(pos1: np.ndarray, pos2: np.ndarray) -> np.ndarray: of dimensions (e.g. 3 for x, y, z). The relationship N <= M must be true. + threshold : float, optional. Defaults to np.inf. + The threshold to use to consider a pair a bad match. Any match pair + whose distance is greater or equal to the threshold will be considered + to be at great distance to each other. + + It'll still show up in the matching, but it will have the least + priority for a match because that match will not reduce the overall + cost across all points. + + Use `analyze_point_matches` subsequently to remove the "bad" matches. Returns ------- @@ -565,12 +579,35 @@ def match_points(pos1: np.ndarray, pos2: np.ndarray) -> np.ndarray: I.e. the match is (pos1[i], pos2[matches[i]]). """ # based on https://en.wikipedia.org/wiki/Hungarian_algorithm + pos1 = pos1.astype(np.float64) + pos2 = pos2.astype(np.float64) + + if len(pos1.shape) != 2 or len(pos2.shape) != 2: + raise ValueError("The input arrays must have exactly 2 dimensions") + n_rows = pos1.shape[0] n_cols = pos2.shape[0] if n_rows > n_cols: raise ValueError( "The length of pos1 must be less than or equal to length of pos2" ) + if pos1.shape[1] != pos2.shape[1]: + raise ValueError("The two inputs have different number of columns") + + inf_dist = 0 + have_threshold = threshold != np.inf + # If we use a threshold, find the largest enclosing (hyper) cube and use + # the distance between two opposing corners as the maximum distance we + # can ever see. 
Use that as dist of points further than threshold + if have_threshold: + # for each col, find the range of points and pick greatest col + largest_side = 0 + for i in range(pos1.shape[1]): + bottom = min(np.min(pos1[:, i]), np.min(pos2[:, i])) + top = max(np.max(pos1[:, i]), np.max(pos2[:, i])) + largest_side = max(largest_side, top - bottom) + # make cube using the largest col range + inf_dist = math.sqrt(pos1.shape[1]) * (largest_side + 1) potentials_rows = np.zeros(n_rows) potentials_cols = np.zeros(n_cols + 1) @@ -598,11 +635,16 @@ def match_points(pos1: np.ndarray, pos2: np.ndarray) -> np.ndarray: for col_i in range(n_cols): if not col_used[col_i]: - dist = np.sum(np.square(pos1[row_cur, :] - pos2[col_i, :])) + # use sqrt to match threshold which is in actual distance + dist = np.sqrt( + np.sum(np.square(pos1[row_cur, :] - pos2[col_i, :])) + ) if dist == np.inf: raise ValueError( "The distance between point is too large" ) + if have_threshold and dist >= threshold: + dist = inf_dist cur = ( dist @@ -693,6 +735,9 @@ def analyze_point_matches( distance is greater than the threshold will be removed from the matching and added to the missing_pos1 and missing_pos2 arrays. + To get a best global optimum, use the same threshold you used in + `match_points`. + Returns ------- tuple : (np.ndarray, np.ndarray, np.ndarray) @@ -706,17 +751,17 @@ def analyze_point_matches( # indices and mask on indices pos2_n = len(pos2) pos2_i = np.arange(pos2_n) - pos2_mask = np.ones(pos2_n, dtype=np.bool_) + pos2_unmatched = np.ones(pos2_n, dtype=np.bool_) # those in pos2 who have a match in pos1 - pos2_mask[matches] = False + pos2_unmatched[matches] = False # all the pos2 that have no matches from pos1 - missing_pos2 = pos2_i[pos2_mask] + missing_pos2 = pos2_i[pos2_unmatched] # repackage matches so the first column is the pos1 idx and 2nd column is # the corresponding pos2 index matches_indices = np.stack((np.arange(len(pos1)), matches), axis=1) - dist = np.sqrt(np.sum(np.square(pos1 - pos2[matches]), axis=1)) + dist = np.sqrt(np.sum(np.square(pos1 - pos2[matches, :]), axis=1)) too_large = dist >= threshold bad_matches = matches_indices[too_large, :] good_matches = matches_indices[np.logical_not(too_large), :] diff --git a/pyproject.toml b/pyproject.toml index ed31c64..6d5be4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,7 +136,7 @@ python = extras = dev commands = - pytest -v --color=yes --cov=brainglobe_utils --cov-report=xml + NUMBA_DISABLE_JIT=1 pytest -v --color=yes --cov=brainglobe_utils --cov-report=xml passenv = CI GITHUB_ACTIONS diff --git a/tests/tests/test_cells/test_cells.py b/tests/tests/test_cells/test_cells.py index 05f3c33..962d875 100644 --- a/tests/tests/test_cells/test_cells.py +++ b/tests/tests/test_cells/test_cells.py @@ -214,3 +214,20 @@ def test_conversion_typed_and_untyped_cell(): assert cells.UntypedCell.from_cell(typed_cell) == untyped_cell assert untyped_cell.to_cell() == typed_cell + + +def test_cells_to_np_cell_type(): + items = [ + cells.Cell((0, 1, 2), cells.Cell.UNKNOWN), + cells.Cell((3, 4, 5), cells.Cell.CELL), + ] + + assert np.array_equal( + cells.to_numpy_pos(items), + [[0, 1, 2], [3, 4, 5]], + ) + + assert np.array_equal( + cells.to_numpy_pos(items, cells.Cell.CELL), + [[3, 4, 5]], + ) diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index e00ac6a..8543f8c 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -1,8 +1,16 @@ +import math from typing import List import numpy 
as np +import pytest -from brainglobe_utils.cells.cells import Cell, from_numpy_pos, match_cells +import brainglobe_utils.cells.cells as cell_utils +from brainglobe_utils.cells.cells import ( + Cell, + from_numpy_pos, + match_cells, + match_points, +) def as_cell(x: List[float]): @@ -67,7 +75,83 @@ def test_cell_matches_threshold(): assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab - a_, ab, b_ = match_cells(a, b, threshold=10) + a_, ab, b_ = match_cells(a, b, threshold=math.sqrt(3) * 11) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab + + +def test_global_optimum_with_threshold_original_pr(): + cells1 = [ + Cell((0, 0, 0), Cell.UNKNOWN), + Cell((12, 0, 0), Cell.UNKNOWN), + ] + cells2 = [ + Cell((10, 0, 0), Cell.UNKNOWN), + Cell((22, 0, 0), Cell.UNKNOWN), + ] + + # without threshold, the global optimum pars points (0, 10), (12, 22) at a + # global cost of 20. The other pairing would have cost of 24 + missing_c1, good_matches, missing_c2 = match_cells( + cells1, cells2, threshold=np.inf + ) + assert not missing_c1 + assert not missing_c2 + assert good_matches == [[0, 0], [1, 1]] + + # with threshold, the previous pairing should not be considered good. + # Instead, only (12, 10) is a good match. So while total cost is 24, + # we only care about the cost of 2 during the matching algorithm + missing_c1, good_matches, missing_c2 = match_cells( + cells1, cells2, threshold=5 + ) + # before we added the threshold to match_points, the following applies + # assert missing_c1 == [0, 1] + # assert missing_c2 == [0, 1] + # assert not good_matches + # with threshold in match_points, this is true - as wanted + assert missing_c1 == [ + 0, + ] + assert missing_c2 == [ + 1, + ] + assert good_matches == [[1, 0]] + + +def test_rows_greater_than_cols(): + with pytest.raises(ValueError): + match_points(np.zeros((5, 3)), np.zeros((4, 3))) + + +def test_unequal_inputs_shape(): + with pytest.raises(ValueError): + match_points(np.zeros((5, 3)), np.zeros((5, 2))) + + +def test_bad_input_shape(): + with pytest.raises(ValueError): + match_points(np.zeros(5), np.zeros(5)) + + with pytest.raises(ValueError): + match_points(np.zeros((5, 4, 6)), np.zeros((5, 4, 6))) + + +def test_progress_already_running(): + a = as_cell([10, 12]) + b = as_cell([10, 12]) + cell_utils.__progress_update.updater = 1 + + try: + with pytest.raises(TypeError): + match_cells(a, b) + finally: + cell_utils.__progress_update.updater = None + + +def test_distance_too_large(): + a = np.array([[1, 2, 3]]) + b = np.array([[1, 2, np.inf]]) + with pytest.raises(ValueError): + match_points(a, b) From 83bb74550e7ad20e77361cddc7bbf225cb5495ca Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 20 May 2024 22:17:29 -0400 Subject: [PATCH 03/13] We don't need max. We can just set to threshold. --- brainglobe_utils/cells/cells.py | 21 +++++---------------- tests/tests/test_cells/test_matches.py | 7 ++++++- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 75cdc28..58616ba 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -581,8 +581,10 @@ def match_points( # based on https://en.wikipedia.org/wiki/Hungarian_algorithm pos1 = pos1.astype(np.float64) pos2 = pos2.astype(np.float64) + # numba pre-checks that arrays are at least 2-dims. 
Us checking would be + # too late and never invoked - if len(pos1.shape) != 2 or len(pos2.shape) != 2: + if pos1.ndim != 2 or pos2.ndim != 2: raise ValueError("The input arrays must have exactly 2 dimensions") n_rows = pos1.shape[0] @@ -594,20 +596,7 @@ def match_points( if pos1.shape[1] != pos2.shape[1]: raise ValueError("The two inputs have different number of columns") - inf_dist = 0 have_threshold = threshold != np.inf - # If we use a threshold, find the largest enclosing (hyper) cube and use - # the distance between two opposing corners as the maximum distance we - # can ever see. Use that as dist of points further than threshold - if have_threshold: - # for each col, find the range of points and pick greatest col - largest_side = 0 - for i in range(pos1.shape[1]): - bottom = min(np.min(pos1[:, i]), np.min(pos2[:, i])) - top = max(np.max(pos1[:, i]), np.max(pos2[:, i])) - largest_side = max(largest_side, top - bottom) - # make cube using the largest col range - inf_dist = math.sqrt(pos1.shape[1]) * (largest_side + 1) potentials_rows = np.zeros(n_rows) potentials_cols = np.zeros(n_cols + 1) @@ -643,8 +632,8 @@ def match_points( raise ValueError( "The distance between point is too large" ) - if have_threshold and dist >= threshold: - dist = inf_dist + if have_threshold and dist > threshold: + dist = threshold cur = ( dist diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index 8543f8c..a7a1aa3 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -131,7 +131,12 @@ def test_unequal_inputs_shape(): def test_bad_input_shape(): - with pytest.raises(ValueError): + # we want to check that a 1-dim array is not accepted. But, numba checks + # the inputs for at least 2-dims because it knows we access the 2dn dim. + # So we have no chance to raise an error ourself. So check numba's error + import numba.core.errors + + with pytest.raises(numba.core.errors.TypingError): match_points(np.zeros(5), np.zeros(5)) with pytest.raises(ValueError): From 48617f67116359177c216a161da88cad891fff7d Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Mon, 20 May 2024 22:30:54 -0400 Subject: [PATCH 04/13] Pass numba disable in env. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6d5be4c..e7a83cd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,9 +136,10 @@ python = extras = dev commands = - NUMBA_DISABLE_JIT=1 pytest -v --color=yes --cov=brainglobe_utils --cov-report=xml + pytest -v --color=yes --cov=brainglobe_utils --cov-report=xml passenv = CI + NUMBA_DISABLE_JIT GITHUB_ACTIONS DISPLAY XAUTHORITY From 1bf364a21b7955cfe65f29e0514df265c947d667 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Wed, 22 May 2024 20:11:43 -0400 Subject: [PATCH 05/13] We need to reset the used cols at each iteration. 
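
Without clearing `col_used` before each row's augmenting-path search, columns
visited while assigning an earlier row stay marked as used, so later searches
scan a shrinking candidate set and can settle on suboptimal assignments.
Resetting the mask at the top of the per-row loop restores the invariant the
algorithm relies on, i.e. (sketch of the fixed loop body):

    for row in range(n_rows):
        col = n_cols
        assignment_row[col] = row
        min_to[:] = np.inf
        prev_col_for_col[:] = -1
        col_used[:] = False  # every column is available again for this search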
--- brainglobe_utils/cells/cells.py | 1 + 1 file changed, 1 insertion(+) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 58616ba..0e04a48 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -614,6 +614,7 @@ def match_points( # min reduced cost over edges from Z to worker w min_to[:] = np.inf prev_col_for_col[:] = -1 + col_used[:] = False # runs at most row + 1 times while assignment_row[col] != -1: From a50f4aa793507801897d13ecee0d24ba1091ec17 Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Wed, 22 May 2024 23:21:56 -0400 Subject: [PATCH 06/13] Add support for pre-extracting zero distance pairs. --- brainglobe_utils/cells/cells.py | 324 +++++++++++++++++++++---- tests/tests/test_cells/test_matches.py | 94 ++++--- 2 files changed, 336 insertions(+), 82 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 0e04a48..606af67 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -434,7 +434,10 @@ def from_numpy_pos(pos: np.ndarray, cell_type: int) -> List[Cell]: def match_cells( - cells: List[Cell], other: List[Cell], threshold: float = np.inf + cells: List[Cell], + other: List[Cell], + threshold: float = np.inf, + pre_match: bool = True, ) -> Tuple[List[int], List[Tuple[int, int]], List[int]]: """ Given two lists of cells. It finds a pairing of cells from `cells` and @@ -470,6 +473,12 @@ def match_cells( The threshold to use to remove bad matches. Any match pair whose distance is greater than the threshold will be exluded from the matching. + pre_match : bool, optional. Defaults to True. + If True, we will (interenally) first efficiently find all the pairs of + `cells` and `others` which are each at the same position in space. Then + we run the optimization to find the best matching on the remaining. + This will significantly speed up the matching, if there are pairs of + cells on top of each other in each set. Returns ------- @@ -500,7 +509,7 @@ def match_cells( __progress_update.updater = progress.update # for each index corresponding to c1, returns the index in c2 that matches try: - assignment = match_points(c1, c2, threshold) + assignment = match_points(c1, c2, threshold, pre_match) progress.close() finally: __progress_update.updater = None @@ -526,75 +535,189 @@ def match_cells( __progress_update.updater = None -def __compare_progress(): +def __compare_progress(n: int = 1) -> None: + """Updates the progress bar by `n`, if there's one set.""" if __progress_update.updater is not None: - __progress_update.updater() + __progress_update.updater(n) @njit -def match_points( - pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf -) -> np.ndarray: +def _find_pairs_sorted( + pos1: np.ndarray, pos2: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ - Given two arrays, each a list of position. For each point in `pos1` it - finds a point in `pos2` such that the distance between the assigned - matches across all `pos1` is minimized. + Given two sorted arrays, returns all the pairs in the arrays (without + replacement) that are at a `np.isclose` distance of each other. - E.g.:: + This is computed in O(N) time. - >>> pos1 = np.array([[20, 10, 30, 40]]).T - >>> pos2 = np.array([[5, 15, 25, 35, 50]]).T - >>> matches = match_points(pos1, pos2) - >>> matches - array([1, 0, 2, 3]) + Parameters + ---------- + pos1 : Sorted (1st axis) np.ndarray of shape `MxK`. + pos2 : Sorted (1st axis) np.ndarray of shape `NxK`. 
+ + Returns + ------- + tuple : + used_mask_1 : Bool np.ndarray of size `M`. + It's True at indices of the rows in `pos1` used up in the pairing. + used_mask_2 : Bool np.ndarray of size `N`. + It's True at indices of the rows in `pos2` used up in the pairing. + paired_indices: np.ndarray of shape `Rx2`. + Each row is a pair of indices to `pos1` and `pos2`, respectively. + Indicating a pair that is `close` to each other. + """ + # mask of pos1/pos2 for the elements used in a identical par + used_mask_1 = np.zeros(pos1.shape[0], dtype=np.bool_) + used_mask_2 = np.zeros(pos2.shape[0], dtype=np.bool_) + n_cols = pos1.shape[1] + + # the pos1/pos2 indices for each pair - at most this many pairs + max_n = min(pos1.shape[0], pos2.shape[0]) + paired_indices = np.zeros((max_n, 2), dtype=np.int64) + + # how many pairs found + used_n = 0 + # next index to check for pair for pos1/2 + pos1_i = 0 + pos2_i = 0 + + # do this in O(N), until we reach end of either array + while pos1_i < max_n and pos2_i < max_n: + # are the two points the same + same = True + for i in range(n_cols): + same = same and np.isclose(pos1[pos1_i, i], pos2[pos2_i, i]) + + # they match + if same: + used_mask_1[pos1_i] = True + used_mask_2[pos2_i] = True + paired_indices[used_n, 0] = pos1_i + paired_indices[used_n, 1] = pos2_i + used_n += 1 + pos1_i += 1 + pos2_i += 1 + else: + # the points are not the same in at least one dim, which is less? + one_is_less = True + for i in range(n_cols): + # for dims until this one (if any), they are the same + if pos1[pos1_i, i] < pos2[pos2_i, i]: # first is less + break + elif pos1[pos1_i, i] > pos2[pos2_i, i]: # second is less + one_is_less = False + break + # they were the same in this axis as well + else: # pragma: no cover + assert False, "at least in one dim it should be different" + + if one_is_less: + # first is less than second, advance first by one + pos1_i += 1 + else: + # second is less than first, advance second by one + pos2_i += 1 + + return used_mask_1, used_mask_2, paired_indices[:used_n, :] + + +def _find_identical_points( + pos1: np.ndarray, pos2: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Given two arrays, returns the set of pairs of points in the arrays + (without replacement) that are at a `np.isclose` distance of each other. + + This is computed in O(NlogN) time, dominated by sorting (internally). + + Parameters + ---------- + pos1 : np.ndarray of shape `MxK`. + pos2 : np.ndarray of shape `NxK`. + + Returns + ------- + tuple : + unpaired1_indices : np.ndarray of size `M - R`. + Array of indices of `pos1` that were not used in any pairs + unpaired2_indices : np.ndarray of size `N - R`. + Array of indices of `pos2` that were not used in any pairs. + paired_indices: np.ndarray of shape `Rx2`. + Each row is a pair of indices to `pos1` and `pos2`, respectively. + Indicating a pair that is `close` to each other. + """ + # sort pos1 and pos2 rows, ordered by columns 1..N. + # lexsort uses rows as keys and sorts the keys from last to first. 
+ # So flip order of rows and then transpose + indices = np.lexsort(np.flip(pos1, axis=1).transpose()) + pos1_sorted = pos1[indices, :] + # original pos1 indices of the sorted elements + orig1_indices = np.arange(len(pos1), dtype=np.int64)[indices] + + indices = np.lexsort(np.flip(pos2, axis=1).transpose()) + pos2_sorted = pos2[indices, :] + orig2_indices = np.arange(len(pos2), dtype=np.int64)[indices] + + # get the zero distance pairs + used_mask_1, used_mask_2, paired_indices = _find_pairs_sorted( + pos1_sorted, pos2_sorted + ) + + # convert the indices back to the original unsorted indices + unpaired1_indices = orig1_indices[np.logical_not(used_mask_1)] + unpaired2_indices = orig2_indices[np.logical_not(used_mask_2)] + paired_indices[:, 0] = orig1_indices[paired_indices[:, 0]] + paired_indices[:, 1] = orig2_indices[paired_indices[:, 1]] + + return unpaired1_indices, unpaired2_indices, paired_indices + + +@njit +def _optimize_pairs( + pos1: np.ndarray, + pos2: np.ndarray, + threshold: float, +) -> np.ndarray: + """ + Implements `match_points` using + https://en.wikipedia.org/wiki/Hungarian_algorithm. Parameters ---------- pos1 : np.ndarray - 2D array of NxK. Where N is number of positions and K is the number - of dimensions (e.g. 3 for x, y, z). + 2D array of NxK. pos2 : np.ndarray - 2D array of MxK. Where M is number of positions and K is the number - of dimensions (e.g. 3 for x, y, z). + 2D array of MxK. - The relationship N <= M must be true. + The relationship N <= M must be true and K must be the same for both. threshold : float, optional. Defaults to np.inf. The threshold to use to consider a pair a bad match. Any match pair whose distance is greater or equal to the threshold will be considered - to be at great distance to each other. + to be at threshold distance (i.e. max distance). It'll still show up in the matching, but it will have the least priority for a match because that match will not reduce the overall cost across all points. - Use `analyze_point_matches` subsequently to remove the "bad" matches. - Returns ------- matches : np.ndarray - 1D array of length N. Each index i in matches corresponds - to index i in `pos1`. The value of index i in matches is the index - j in pos2 that is the best match for that pos1. - - I.e. the match is (pos1[i], pos2[matches[i]]). + 1D array of length N, where index i in matches corresponds + to index i in `pos1` and its value is the index in pos2 + of its best match. """ - # based on https://en.wikipedia.org/wiki/Hungarian_algorithm - pos1 = pos1.astype(np.float64) - pos2 = pos2.astype(np.float64) - # numba pre-checks that arrays are at least 2-dims. 
Us checking would be - # too late and never invoked - - if pos1.ndim != 2 or pos2.ndim != 2: - raise ValueError("The input arrays must have exactly 2 dimensions") - + # we don't check for boundary conditions, just assert because it should be + # checked by caller (match_points) n_rows = pos1.shape[0] n_cols = pos2.shape[0] - if n_rows > n_cols: - raise ValueError( - "The length of pos1 must be less than or equal to length of pos2" - ) - if pos1.shape[1] != pos2.shape[1]: - raise ValueError("The two inputs have different number of columns") + assert len(pos1.shape) == 2 + assert len(pos2.shape) == 2 + assert pos1.shape[1] == pos2.shape[1] + assert n_rows <= n_cols + + pos1 = pos1.astype(np.float64) + pos2 = pos2.astype(np.float64) have_threshold = threshold != np.inf @@ -626,9 +749,10 @@ def match_points( for col_i in range(n_cols): if not col_used[col_i]: # use sqrt to match threshold which is in actual distance - dist = np.sqrt( - np.sum(np.square(pos1[row_cur, :] - pos2[col_i, :])) - ) + dist = 0.0 + for i in range(pos1.shape[1]): + dist += math.pow(pos1[row_cur, i] - pos2[col_i, i], 2) + dist = math.sqrt(dist) if dist == np.inf: raise ValueError( "The distance between point is too large" @@ -670,7 +794,7 @@ def match_points( __compare_progress() # compute match from assignment - matches = np.empty(n_rows, dtype=np.int_) + matches = np.empty(n_rows, dtype=np.int64) for i in range(n_cols): if assignment_row[i] != -1: matches[assignment_row[i]] = i @@ -678,6 +802,110 @@ def match_points( return matches +def match_points( + pos1: np.ndarray, + pos2: np.ndarray, + threshold: float = np.inf, + pre_match: bool = True, +) -> np.ndarray: + """ + Given two arrays, each a list of position. For each point in `pos1` it + finds a point in `pos2` such that the distance between the assigned + matches across all `pos1` is minimized. + + E.g.:: + + >>> pos1 = np.array([[20, 10, 30, 40]]).T + >>> pos2 = np.array([[5, 15, 25, 35, 50]]).T + >>> matches = match_points(pos1, pos2) + >>> matches + array([1, 0, 2, 3]) + + Parameters + ---------- + pos1 : np.ndarray + 2D array of NxK. Where N is number of positions and K is the number + of dimensions (e.g. 3 for x, y, z). + pos2 : np.ndarray + 2D array of MxK. Where M is number of positions and K is the number + of dimensions (e.g. 3 for x, y, z). + + The relationship N <= M must be true. + threshold : float, optional. Defaults to np.inf. + The threshold to use to consider a pair a bad match. Any match pair + whose distance is greater or equal to the threshold will be considered + to be at threshold distance (i.e. the max distance). + + It'll still show up in the matching, but it will have the least + priority for a match because that match will not reduce the overall + cost across all points. + + Use `analyze_point_matches` with the same threshold subsequently to + remove the "bad" matches. + pre_match : bool, optional. Defaults to True. + If True, we will (interenally) first efficiently find all the pairs of + `pos1` and `pos2` which are each at the same position in space. Then + we run the optimization to find the best matching on the remaining. + + If True, it'll significantly speed up the matching, if there are pairs + of points on top of each other across the input lists. + + Returns + ------- + matches : np.ndarray + 1D array of length N. Each index i in matches corresponds + to index i in `pos1`. The value of index i in matches is the index + j in pos2 that is the best match for that pos1. + + I.e. the match is (pos1[i], pos2[matches[i]]). 
+ """ + if len(pos1.shape) != 2 or len(pos2.shape) != 2: + raise ValueError("The input arrays must have exactly 2 dimensions") + + n_rows = pos1.shape[0] + n_cols = pos2.shape[0] + if n_rows > n_cols: + raise ValueError( + "The length of pos1 must be less than or equal to length of pos2" + ) + if pos1.shape[1] != pos2.shape[1]: + raise ValueError("The two inputs have different number of columns") + + if not pre_match: + # do optimization on full inputs + return _optimize_pairs(pos1, pos2, threshold) + + # extract the indices of zero-pairs and remaining points + unpaired1_indices, unpaired2_indices, paired_indices = ( + _find_identical_points(pos1, pos2) + ) + # the number of zero pairs found are done! + __compare_progress(len(paired_indices)) + + # if everything was zero pairs, we're done! + if not len(unpaired1_indices): + # sort by pos1 and then return the corresponding pos2 indices matching + # those pos1 points + pos1_sorted_indices = np.argsort(paired_indices[:, 0]) + return paired_indices[pos1_sorted_indices, 1] + + # extract remaining pos1/post2 and run optimization + pos1 = pos1[unpaired1_indices] + pos2 = pos2[unpaired2_indices] + n_rows = pos1.shape[0] + + matches = _optimize_pairs(pos1, pos2, threshold) + + # map extracted + full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64) + # set pos1 optimized matches to their pos2 indices + full_matches[unpaired1_indices] = unpaired2_indices[matches] + # set the zero pairs pos1 to corresponding pos2 match indices + full_matches[paired_indices[:, 0]] = paired_indices[:, 1] + + return full_matches + + @njit def analyze_point_matches( pos1: np.ndarray, diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index a7a1aa3..f0ba023 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -19,69 +19,76 @@ def as_cell(x: List[float]): return cells -def test_cell_matches_equal_size(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_cell_matches_equal_size(pre_match): a = as_cell([10, 20, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab -def test_cell_matches_larger_other(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_cell_matches_larger_other(pre_match): a = as_cell([1, 12, 100, 80]) b = as_cell([5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [2] assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42, 41]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [3] assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab -def test_cell_matches_larger_cells(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_cell_matches_larger_cells(pre_match): a = as_cell([5, 15, 25, 35, 100]) b = as_cell([1, 12, 100, 80]) - a_, ab, b_ = match_cells(a, b) + 
a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert a_ == [2] assert not b_ assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab -def test_cell_matches_threshold(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_cell_matches_threshold(pre_match): a = as_cell([10, 12, 100, 80]) b = as_cell([0, 5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab - a_, ab, b_ = match_cells(a, b, threshold=math.sqrt(3) * 11) + a_, ab, b_ = match_cells( + a, b, threshold=math.sqrt(3) * 11, pre_match=pre_match + ) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab -def test_global_optimum_with_threshold_original_pr(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_global_optimum_with_threshold_original_pr(pre_match): cells1 = [ Cell((0, 0, 0), Cell.UNKNOWN), Cell((12, 0, 0), Cell.UNKNOWN), @@ -94,7 +101,7 @@ def test_global_optimum_with_threshold_original_pr(): # without threshold, the global optimum pars points (0, 10), (12, 22) at a # global cost of 20. The other pairing would have cost of 24 missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=np.inf + cells1, cells2, threshold=np.inf, pre_match=pre_match ) assert not missing_c1 assert not missing_c2 @@ -104,7 +111,7 @@ def test_global_optimum_with_threshold_original_pr(): # Instead, only (12, 10) is a good match. So while total cost is 24, # we only care about the cost of 2 during the matching algorithm missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=5 + cells1, cells2, threshold=5, pre_match=pre_match ) # before we added the threshold to match_points, the following applies # assert missing_c1 == [0, 1] @@ -120,43 +127,62 @@ def test_global_optimum_with_threshold_original_pr(): assert good_matches == [[1, 0]] -def test_rows_greater_than_cols(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_rows_greater_than_cols(pre_match): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((4, 3))) + match_points(np.zeros((5, 3)), np.zeros((4, 3)), pre_match=pre_match) -def test_unequal_inputs_shape(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_unequal_inputs_shape(pre_match): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((5, 2))) - + match_points(np.zeros((5, 3)), np.zeros((5, 2)), pre_match=pre_match) -def test_bad_input_shape(): - # we want to check that a 1-dim array is not accepted. But, numba checks - # the inputs for at least 2-dims because it knows we access the 2dn dim. - # So we have no chance to raise an error ourself. 
So check numba's error - import numba.core.errors - with pytest.raises(numba.core.errors.TypingError): - match_points(np.zeros(5), np.zeros(5)) +@pytest.mark.parametrize("pre_match", [True, False]) +def test_bad_input_shape(pre_match): + # has to be 2 dims + with pytest.raises(ValueError): + match_points(np.zeros(5), np.zeros(5), pre_match=pre_match) with pytest.raises(ValueError): - match_points(np.zeros((5, 4, 6)), np.zeros((5, 4, 6))) + match_points( + np.zeros((5, 4, 6)), np.zeros((5, 4, 6)), pre_match=pre_match + ) -def test_progress_already_running(): +@pytest.mark.parametrize("pre_match", [True, False]) +def test_progress_already_running(pre_match): a = as_cell([10, 12]) b = as_cell([10, 12]) cell_utils.__progress_update.updater = 1 try: with pytest.raises(TypeError): - match_cells(a, b) + match_cells(a, b, pre_match=pre_match) finally: cell_utils.__progress_update.updater = None -def test_distance_too_large(): - a = np.array([[1, 2, 3]]) - b = np.array([[1, 2, np.inf]]) +@pytest.mark.parametrize("pre_match", [True, False]) +def test_distance_too_large(pre_match): + a = np.array([[1, 2, 3]]).T + b = np.array([[1, 2, np.inf]]).T with pytest.raises(ValueError): - match_points(a, b) + match_points(a, b, pre_match=pre_match) + + +@pytest.mark.parametrize("pre_match", [True, False]) +def test_contains_identical_points(pre_match): + a = np.array([[1, 10], [5, 7], [22, 12]]) + b = np.array([[5, 7], [7, 1], [21, 10]]) + matching = match_points(a, b, pre_match=pre_match) + assert np.array_equal(matching, [1, 0, 2]) + + +@pytest.mark.parametrize("pre_match", [True, False]) +def test_only_identical_points(pre_match): + a = np.array([[1, 2, 3]]).T + b = np.array([[2, 3, 5, 1]]).T + matching = match_points(a, b, pre_match=pre_match) + assert np.array_equal(matching, [3, 0, 1]) From e1ca9726cbadfb373d6b782c0abfcbb50c9e37ca Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 23 May 2024 00:51:38 -0400 Subject: [PATCH 07/13] Add option to use scipy optimization. --- brainglobe_utils/cells/cells.py | 84 ++++++++++++++- tests/tests/test_cells/test_matches.py | 137 +++++++++++++++++-------- 2 files changed, 174 insertions(+), 47 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 606af67..64be3ab 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -23,6 +23,7 @@ import numpy as np from numba import njit, objmode +from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -438,6 +439,7 @@ def match_cells( other: List[Cell], threshold: float = np.inf, pre_match: bool = True, + use_scipy: bool = False, ) -> Tuple[List[int], List[Tuple[int, int]], List[int]]: """ Given two lists of cells. It finds a pairing of cells from `cells` and @@ -477,8 +479,14 @@ def match_cells( If True, we will (interenally) first efficiently find all the pairs of `cells` and `others` which are each at the same position in space. Then we run the optimization to find the best matching on the remaining. + This will significantly speed up the matching, if there are pairs of cells on top of each other in each set. + use_scipy : bool, optional. Defaults to False. + Whether to use scipy `linear_sum_assignment` to find the optimal + matching. Otherwise, we use our own implementation. + + See `match_points` for consideration details. 
     Returns
     -------
     tuple :

     __progress_update.updater = progress.update
     # for each index corresponding to c1, returns the index in c2 that matches
     try:
-        assignment = match_points(c1, c2, threshold, pre_match)
+        assignment = match_points(c1, c2, threshold, pre_match, use_scipy)
         progress.close()
     finally:
         __progress_update.updater = None

+def _optimize_pairs_scipy(
+    pos1: np.ndarray,
+    pos2: np.ndarray,
+) -> np.ndarray:
+    """
+    Implements `match_points` using scipy's `linear_sum_assignment`.
+
+    The function has a memory cost of N*M*K*8 bytes. When `K=3` and `N=M`,
+    approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k`
+    points, we need 16GB. For `N=M=75.7k` points, we need 128GB.
+
+    Parameters
+    ----------
+    pos1 : np.ndarray
+        2D array of NxK.
+    pos2 : np.ndarray
+        2D array of MxK.
+
+        The relationship N <= M must be true and K must be the same for both.
+
+    Returns
+    -------
+    matches : np.ndarray
+        1D array of length N, where index i in matches corresponds
+        to index i in `pos1` and its value is the index in pos2
+        of its best match.
+    """
+    # we don't check for boundary conditions, just assert, because they are
+    # checked by the caller (match_points)
+    n_rows = pos1.shape[0]
+    n_cols = pos2.shape[0]
+    assert len(pos1.shape) == 2
+    assert len(pos2.shape) == 2
+    assert pos1.shape[1] == pos2.shape[1]
+    assert n_rows <= n_cols
+
+    # NxK -> Nx1xK
+    pos1 = pos1[:, np.newaxis, :]
+    # MxK -> 1xMxK
+    pos2 = pos2[np.newaxis, :, :]
+    # dist broadcasts to NxMxK
+    dist = pos1 - pos2
+    # cost is NxM
+    cost = np.sqrt(np.sum(np.square(dist), axis=2))
+    # result is sorted by rows
+    rows, cols = linear_sum_assignment(cost)
+    # N <= M, so rows and cols are each of length N
+    return cols
+
+
 def match_points(
     pos1: np.ndarray,
     pos2: np.ndarray,
     threshold: float = np.inf,
     pre_match: bool = True,
+    use_scipy: bool = False,
 ) -> np.ndarray:

         If True, it'll significantly speed up the matching, if there are pairs
         of points on top of each other across the input lists.
+    use_scipy : bool, optional. Defaults to False.
+        Whether to use scipy `linear_sum_assignment` to find the optimal
+        matching. Otherwise, we use our own implementation.
+
+        Our implementation is very memory efficient. Using scipy, we have
+        e.g. a memory requirement of approximately 1GB for 6.7k points
+        (`N=M=6.7k`), 16GB for 26.8k points, and 128GB for 75.7k points.
+
+        If `pre_match` is used and it eliminates many zero-distance points,
+        using scipy becomes feasible. Otherwise, use our implementation.
+
+        Note: When using scipy, we *don't* take the threshold into account
+        until after the matching is complete.
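+
+        As a rough back-of-the-envelope check of those numbers (an
+        illustration only, assuming float64 intermediates and K=3)::
+
+            >>> (6689 ** 2 * 3 * 8) / 1e9  # NxMxK distance array, in GB
+            1.073825304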
Returns ------- @@ -873,7 +945,10 @@ def match_points( if not pre_match: # do optimization on full inputs - return _optimize_pairs(pos1, pos2, threshold) + if use_scipy: + return _optimize_pairs_scipy(pos1, pos2) + else: + return _optimize_pairs(pos1, pos2, threshold) # extract the indices of zero-pairs and remaining points unpaired1_indices, unpaired2_indices, paired_indices = ( @@ -894,7 +969,10 @@ def match_points( pos2 = pos2[unpaired2_indices] n_rows = pos1.shape[0] - matches = _optimize_pairs(pos1, pos2, threshold) + if use_scipy: + matches = _optimize_pairs_scipy(pos1, pos2) + else: + matches = _optimize_pairs(pos1, pos2, threshold) # map extracted full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64) diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index f0ba023..add83bd 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -19,76 +19,85 @@ def as_cell(x: List[float]): return cells +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_equal_size(pre_match): +def test_cell_matches_equal_size(pre_match, use_scipy): a = as_cell([10, 20, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_other(pre_match): +def test_cell_matches_larger_other(pre_match, use_scipy): a = as_cell([1, 12, 100, 80]) b = as_cell([5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [2] assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42, 41]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [3] assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_cells(pre_match): +def test_cell_matches_larger_cells(pre_match, use_scipy): a = as_cell([5, 15, 25, 35, 100]) b = as_cell([1, 12, 100, 80]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert a_ == [2] assert not b_ assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_threshold(pre_match): +def test_cell_matches_threshold(pre_match, use_scipy): a = as_cell([10, 12, 100, 80]) b = as_cell([0, 5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, 
pre_match=pre_match) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) assert not a_ assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab a_, ab, b_ = match_cells( - a, b, threshold=math.sqrt(3) * 11, pre_match=pre_match + a, + b, + threshold=math.sqrt(3) * 11, + pre_match=pre_match, + use_scipy=use_scipy, ) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_global_optimum_with_threshold_original_pr(pre_match): +def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): cells1 = [ Cell((0, 0, 0), Cell.UNKNOWN), Cell((12, 0, 0), Cell.UNKNOWN), @@ -101,7 +110,11 @@ def test_global_optimum_with_threshold_original_pr(pre_match): # without threshold, the global optimum pars points (0, 10), (12, 22) at a # global cost of 20. The other pairing would have cost of 24 missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=np.inf, pre_match=pre_match + cells1, + cells2, + threshold=np.inf, + pre_match=pre_match, + use_scipy=use_scipy, ) assert not missing_c1 assert not missing_c2 @@ -110,79 +123,115 @@ def test_global_optimum_with_threshold_original_pr(pre_match): # with threshold, the previous pairing should not be considered good. # Instead, only (12, 10) is a good match. So while total cost is 24, # we only care about the cost of 2 during the matching algorithm + # BUT, only when not using scipy. With scipy it doesn't account for + # threshold during the matching, so it'll do it after missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=5, pre_match=pre_match + cells1, cells2, threshold=5, pre_match=pre_match, use_scipy=use_scipy ) - # before we added the threshold to match_points, the following applies - # assert missing_c1 == [0, 1] - # assert missing_c2 == [0, 1] - # assert not good_matches - # with threshold in match_points, this is true - as wanted - assert missing_c1 == [ - 0, - ] - assert missing_c2 == [ - 1, - ] - assert good_matches == [[1, 0]] - - + # before we added the threshold to match_points, the following applies to + # both scipy and our own. 
After the fix, this is only True for scipy + if use_scipy: + assert missing_c1 == [0, 1] + assert missing_c2 == [0, 1] + assert not good_matches + else: + # with threshold in match_points, this is true - as wanted + assert missing_c1 == [ + 0, + ] + assert missing_c2 == [ + 1, + ] + assert good_matches == [[1, 0]] + + +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_rows_greater_than_cols(pre_match): +def test_rows_greater_than_cols(pre_match, use_scipy): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((4, 3)), pre_match=pre_match) + match_points( + np.zeros((5, 3)), + np.zeros((4, 3)), + pre_match=pre_match, + use_scipy=use_scipy, + ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_unequal_inputs_shape(pre_match): +def test_unequal_inputs_shape(pre_match, use_scipy): with pytest.raises(ValueError): - match_points(np.zeros((5, 3)), np.zeros((5, 2)), pre_match=pre_match) + match_points( + np.zeros((5, 3)), + np.zeros((5, 2)), + pre_match=pre_match, + use_scipy=use_scipy, + ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_bad_input_shape(pre_match): +def test_bad_input_shape(pre_match, use_scipy): # has to be 2 dims with pytest.raises(ValueError): - match_points(np.zeros(5), np.zeros(5), pre_match=pre_match) + match_points( + np.zeros(5), np.zeros(5), pre_match=pre_match, use_scipy=use_scipy + ) with pytest.raises(ValueError): match_points( - np.zeros((5, 4, 6)), np.zeros((5, 4, 6)), pre_match=pre_match + np.zeros((5, 4, 6)), + np.zeros((5, 4, 6)), + pre_match=pre_match, + use_scipy=use_scipy, ) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_progress_already_running(pre_match): +def test_progress_already_running(pre_match, use_scipy): a = as_cell([10, 12]) b = as_cell([10, 12]) cell_utils.__progress_update.updater = 1 try: with pytest.raises(TypeError): - match_cells(a, b, pre_match=pre_match) + match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) finally: cell_utils.__progress_update.updater = None +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_distance_too_large(pre_match): +def test_distance_too_large(pre_match, use_scipy): a = np.array([[1, 2, 3]]).T b = np.array([[1, 2, np.inf]]).T with pytest.raises(ValueError): - match_points(a, b, pre_match=pre_match) + match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_identical_points(pre_match): +def test_contains_identical_points(pre_match, use_scipy): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[5, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match) + matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) assert np.array_equal(matching, [1, 0, 2]) +@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_only_identical_points(pre_match): +def test_contains_only_identical_points(pre_match, use_scipy): a = np.array([[1, 2, 3]]).T b = np.array([[2, 3, 5, 1]]).T - matching = match_points(a, b, pre_match=pre_match) + matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) assert np.array_equal(matching, [3, 0, 1]) + + 
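Note: the identical-points tests above pin down the `pre_match` fast path. A
minimal sketch of the property they assert, reusing the values from
`test_contains_identical_points` (both code paths are expected to return the
same assignment; this is an illustration, not part of the patch)::

    import numpy as np
    from brainglobe_utils.cells.cells import match_points

    a = np.array([[1, 10], [5, 7], [22, 12]])
    b = np.array([[5, 7], [7, 1], [21, 10]])

    # (5, 7) occurs in both arrays, so pre_match pairs it up front and only
    # the remaining points go through the full optimization; the final
    # assignment is identical with the fast path on or off
    assert np.array_equal(match_points(a, b, pre_match=True), [1, 0, 2])
    assert np.array_equal(match_points(a, b, pre_match=False), [1, 0, 2])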
+@pytest.mark.parametrize("use_scipy", [True, False])
+@pytest.mark.parametrize("pre_match", [True, False])
+def test_contains_no_identical_points(pre_match, use_scipy):
+    a = np.array([[1, 10], [5, 7], [22, 12]])
+    b = np.array([[6, 7], [7, 1], [21, 10]])
+    matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy)
+    assert np.array_equal(matching, [1, 0, 2])

From fa70e65faa755eb9512d18063b151e924a571731 Mon Sep 17 00:00:00 2001
From: Matt Einhorn
Date: Thu, 23 May 2024 01:38:41 -0400
Subject: [PATCH 08/13] Scipy doesn't have callbacks, so don't show progress.

---
 brainglobe_utils/cells/cells.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py
index 64be3ab..7d97fa4 100644
--- a/brainglobe_utils/cells/cells.py
+++ b/brainglobe_utils/cells/cells.py
@@ -513,15 +513,21 @@ def match_cells(
     if flip:
         c1, c2 = c2, c1

-    progress = tqdm(desc="Matching cells", total=len(c1), unit="cells")
-    __progress_update.updater = progress.update
+    progress = None
+    if not use_scipy:
+        # with scipy we don't have callbacks, so no progress updates
+        progress = tqdm(desc="Matching cells", total=len(c1), unit="cells")
+        __progress_update.updater = progress.update
+
     # for each index corresponding to c1, returns the index in c2 that matches
     try:
         assignment = match_points(c1, c2, threshold, pre_match, use_scipy)
-        progress.close()
     finally:
         __progress_update.updater = None

+    if progress is not None:
+        progress.close()
+
     missing_c1, good_matches, missing_c2 = analyze_point_matches(
         c1, c2, assignment, threshold
     )

From 041b9daa9c423e2b86c93cce1a064720a0cea2cb Mon Sep 17 00:00:00 2001
From: Matt Einhorn
Date: Mon, 3 Jun 2024 02:34:56 -0400
Subject: [PATCH 09/13] Apply suggestions from code review

Co-authored-by: Alessandro Felder

---
 brainglobe_utils/cells/cells.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py
index 7d97fa4..46a08ec 100644
--- a/brainglobe_utils/cells/cells.py
+++ b/brainglobe_utils/cells/cells.py
@@ -476,7 +476,7 @@ def match_cells(
         distance is greater than the threshold will be exluded from the
         matching.
     pre_match : bool, optional. Defaults to True.
-        If True, we will (interenally) first efficiently find all the pairs of
+        If True, we will (internally) first efficiently find all the pairs of
         `cells` and `others` which are each at the same position in space.
         Then we run the optimization to find the best matching on the
        remaining.
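To illustrate the progress-reporting change in PATCH 08 above, a sketch under
the signatures introduced in this series (the positions are arbitrary example
values, not project data)::

    from brainglobe_utils.cells.cells import Cell, match_cells

    cells = [Cell([10, 10, 10], Cell.UNKNOWN), Cell([20, 20, 20], Cell.UNKNOWN)]
    other = [Cell([11, 11, 11], Cell.UNKNOWN), Cell([22, 22, 22], Cell.UNKNOWN)]

    # the numba path reports progress through tqdm via the module-level
    # __progress_update hook
    match_cells(cells, other, use_scipy=False)

    # linear_sum_assignment offers no callback, so the scipy path runs
    # without a progress bar
    match_cells(cells, other, use_scipy=True)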
@@ -769,7 +769,7 @@ def _optimize_pairs(
                 dist = math.sqrt(dist)
                 if dist == np.inf:
                     raise ValueError(
-                        "The distance between point is too large"
+                        "The distance between points is too large"
                     )
                 if have_threshold and dist > threshold:
                     dist = threshold

From f7d190011ebbd271250da44464ce43701bfc3eb5 Mon Sep 17 00:00:00 2001
From: alessandrofelder
Date: Mon, 3 Jun 2024 13:04:13 +0100
Subject: [PATCH 10/13] fixture to provide cell data for regression testing

---
 tests/tests/conftest.py                | 21 +++++++++++++++
 tests/tests/test_cells/test_matches.py | 37 ++++++++++++++++++++++++++
 2 files changed, 58 insertions(+)

diff --git a/tests/tests/conftest.py b/tests/tests/conftest.py
index 65a23bd..2ac4e08 100644
--- a/tests/tests/conftest.py
+++ b/tests/tests/conftest.py
@@ -1,5 +1,6 @@
 from pathlib import Path

+import pooch
 import pytest


@@ -7,3 +8,23 @@ def data_path():
     """Directory storing all test data"""
     return Path(__file__).parent.parent / "data"
+
+
+@pytest.fixture
+def test_data_registry():
+    """
+    Create a test data registry for BrainGlobe.
+
+    Returns:
+        pooch.Pooch: The test data registry object.
+
+    """
+    registry = pooch.create(
+        path=pooch.os_cache("brainglobe_test_data"),
+        base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/",
+        registry={
+            "cellfinder/cells-z-1000-1050.xml": None,
+            "cellfinder/other-cells-z-1000-1050.xml": None,
+        },
+    )
+    return registry
diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py
index add83bd..6bb7657 100644
--- a/tests/tests/test_cells/test_matches.py
+++ b/tests/tests/test_cells/test_matches.py
@@ -11,6 +11,36 @@
     match_cells,
     match_points,
 )
+from brainglobe_utils.IO.cells import get_cells
+
+
+@pytest.fixture
+def cells_and_other_cells(test_data_registry):
+    """
+    Provides real-life cell coordinates from a CFOS-labelled brain, as
+    detected by two cellfinder versions (before and after cellfinder PR #398).
+    Intended to be used for regression testing our cell matching code.
+
+    Parameters
+    ----------
+    test_data_registry : pooch.Pooch
+        The BrainGlobe test data registry.
+
+    Returns
+    -------
+    tuple of (List[Cell], List[Cell])
+        The cell data loaded from each of the two versions.
+
+    """
+    cell_data_path = test_data_registry.fetch(
+        "cellfinder/cells-z-1000-1050.xml"
+    )
+    other_cell_data_path = test_data_registry.fetch(
+        "cellfinder/other_cells-z-1000-1050.xml"
+    )
+    cell_data = get_cells(cell_data_path)
+    other_cell_data = get_cells(other_cell_data_path)
+    return cell_data, other_cell_data


 def as_cell(x: List[float]):
@@ -19,6 +49,13 @@ def as_cell(x: List[float]):
     return cells


+@pytest.mark.xfail
+def test_cell_matching_regression(cells_and_other_cells):
+    cells, other_cells = cells_and_other_cells
+    # TODO implement cell matching regression test here, then remove xfail
+    assert False
+
+
 @pytest.mark.parametrize("use_scipy", [True, False])
 @pytest.mark.parametrize("pre_match", [True, False])
 def test_cell_matches_equal_size(pre_match, use_scipy):

From baea2ad85c70cbfbd24c0de25764bed0514107e9 Mon Sep 17 00:00:00 2001
From: Matt Einhorn
Date: Mon, 3 Jun 2024 19:33:42 -0400
Subject: [PATCH 11/13] Remove scipy as an option and use scipy as a test
 reference instead.
--- brainglobe_utils/cells/cells.py | 90 +-------------- tests/tests/test_cells/test_matches.py | 151 +++++++++++++------------ 2 files changed, 86 insertions(+), 155 deletions(-) diff --git a/brainglobe_utils/cells/cells.py b/brainglobe_utils/cells/cells.py index 46a08ec..2062d0c 100644 --- a/brainglobe_utils/cells/cells.py +++ b/brainglobe_utils/cells/cells.py @@ -23,7 +23,6 @@ import numpy as np from numba import njit, objmode -from scipy.optimize import linear_sum_assignment from tqdm import tqdm @@ -439,7 +438,6 @@ def match_cells( other: List[Cell], threshold: float = np.inf, pre_match: bool = True, - use_scipy: bool = False, ) -> Tuple[List[int], List[Tuple[int, int]], List[int]]: """ Given two lists of cells. It finds a pairing of cells from `cells` and @@ -482,11 +480,6 @@ def match_cells( This will significantly speed up the matching, if there are pairs of cells on top of each other in each set. - use_scipy : bool, optional. Defaults to False. - Whether to use scipy `linear_sum_assignment` to find the optimal - matching. Otherwise, we use our own implementation. - - See `match_points` for consideration details. Returns ------- @@ -513,15 +506,12 @@ def match_cells( if flip: c1, c2 = c2, c1 - progress = None - if not use_scipy: - # with scipy we don't have callbacks so no updates - progress = tqdm(desc="Matching cells", total=len(c1), unit="cells") - __progress_update.updater = progress.update + progress = tqdm(desc="Matching cells", total=len(c1), unit="cells") + __progress_update.updater = progress.update # for each index corresponding to c1, returns the index in c2 that matches try: - assignment = match_points(c1, c2, threshold, pre_match, use_scipy) + assignment = match_points(c1, c2, threshold, pre_match) finally: __progress_update.updater = None @@ -816,62 +806,11 @@ def _optimize_pairs( return matches -def _optimize_pairs_scipy( - pos1: np.ndarray, - pos2: np.ndarray, -) -> np.ndarray: - """ - Implements `match_points` using scipy's `linear_sum_assignment`. - - The function has a memory cost of N*M*k*8 bytes. When `K=3` and `N=M`, - approximately 1GB is required for `N=M=6689` points. For `N=M=26.8k` - points, we need 16GB. For `N=M=75.7k` points, we need 128GB. - - Parameters - ---------- - pos1 : np.ndarray - 2D array of NxK. - pos2 : np.ndarray - 2D array of MxK. - - The relationship N <= M must be true and K must be the same for both. - - Returns - ------- - matches : np.ndarray - 1D array of length N, where index i in matches corresponds - to index i in `pos1` and its value is the index in pos2 - of its best match. - """ - # we don't check for boundary conditions, just assert because it should be - # checked by caller (match_points) - n_rows = pos1.shape[0] - n_cols = pos2.shape[0] - assert len(pos1.shape) == 2 - assert len(pos2.shape) == 2 - assert pos1.shape[1] == pos2.shape[1] - assert n_rows <= n_cols - - # Mxk -> M1K - pos1 = pos1[:, np.newaxis, :] - # Nxk -> 1NK - pos2 = pos2[np.newaxis, :, :] - # dist is MNK - dist = pos1 - pos2 - # cost is MN - cost = np.sqrt(np.sum(np.square(dist), axis=2)) - # result is sorted by rows - rows, cols = linear_sum_assignment(cost) - # M <= N, so cols and rows is size M - return cols - - def match_points( pos1: np.ndarray, pos2: np.ndarray, threshold: float = np.inf, pre_match: bool = True, - use_scipy: bool = False, ) -> np.ndarray: """ Given two arrays, each a list of position. 
For each point in `pos1` it @@ -914,19 +853,6 @@ def match_points( If True, it'll significantly speed up the matching, if there are pairs of points on top of each other across the input lists. - use_scipy : bool, optional. Defaults to False. - Whether to use scipy `linear_sum_assignment` to find the optimal - matching. Otherwise, we use our own implementation. - - Our implementation is very memory efficient. Using scipy, we have e.g. - a memory requirement of approximately 1GB for 6.7k points (`N=M=6.7k`), - 16GB for 26.8k points, and 128GB for 75.7k points. - - If `pre_match` is used, and it eliminates many zero-distance points, - using numpy is feasable. Otherwise, use our implementation. - - Note: When using scipy, we *don't* take the threshold into account - until after the matching is complete. Returns ------- @@ -951,10 +877,7 @@ def match_points( if not pre_match: # do optimization on full inputs - if use_scipy: - return _optimize_pairs_scipy(pos1, pos2) - else: - return _optimize_pairs(pos1, pos2, threshold) + return _optimize_pairs(pos1, pos2, threshold) # extract the indices of zero-pairs and remaining points unpaired1_indices, unpaired2_indices, paired_indices = ( @@ -975,10 +898,7 @@ def match_points( pos2 = pos2[unpaired2_indices] n_rows = pos1.shape[0] - if use_scipy: - matches = _optimize_pairs_scipy(pos1, pos2) - else: - matches = _optimize_pairs(pos1, pos2, threshold) + matches = _optimize_pairs(pos1, pos2, threshold) # map extracted full_matches = np.empty(n_rows + len(paired_indices), dtype=np.int64) diff --git a/tests/tests/test_cells/test_matches.py b/tests/tests/test_cells/test_matches.py index 6bb7657..e14d447 100644 --- a/tests/tests/test_cells/test_matches.py +++ b/tests/tests/test_cells/test_matches.py @@ -3,13 +3,16 @@ import numpy as np import pytest +from scipy.optimize import linear_sum_assignment import brainglobe_utils.cells.cells as cell_utils from brainglobe_utils.cells.cells import ( Cell, + analyze_point_matches, from_numpy_pos, match_cells, match_points, + to_numpy_pos, ) from brainglobe_utils.IO.cells import get_cells @@ -36,7 +39,7 @@ def cells_and_other_cells(test_data_registry): "cellfinder/cells-z-1000-1050.xml" ) other_cell_data_path = test_data_registry.fetch( - "cellfinder/other_cells-z-1000-1050.xml" + "cellfinder/other-cells-z-1000-1050.xml" ) cell_data = get_cells(cell_data_path) other_cell_data = get_cells(other_cell_data_path) @@ -49,73 +52,107 @@ def as_cell(x: List[float]): return cells -@pytest.mark.xfail def test_cell_matching_regression(cells_and_other_cells): cells, other_cells = cells_and_other_cells - # TODO implement cell matching regression test here, then remove xfail - assert False + np_cells = to_numpy_pos(cells) + np_other = to_numpy_pos(other_cells) + + # only run matching on unpaired to reduce computation + unpaired1_indices, unpaired2_indices, paired_indices = ( + cell_utils._find_identical_points(np_cells, np_other) + ) + np_cells = np_cells[unpaired1_indices] + np_other = np_other[unpaired2_indices] + + # happens to be true for this dataset + assert len(np_cells) < len(np_other), "must be true to pass to match" + + # get matches + matches = match_points(np_cells, np_other, pre_match=False) + missing_cells, good, missing_other = analyze_point_matches( + np_cells, np_other, matches + ) + good = np.array(good) + assert not len(missing_cells), "all cells must be matched" + + # get cost + a = np_cells[good[:, 0], :] + b = np_other[good[:, 1], :] + cost_our = np.sum(np.sqrt(np.sum(np.square(a - b), axis=1))) + + # get scipy 
cost + # Mxk -> M1K + pos1 = np_cells[:, np.newaxis, :] + # Nxk -> 1NK + pos2 = np_other[np.newaxis, :, :] + # dist is MNK + dist = pos1 - pos2 + # cost is MN + cost_mat = np.sqrt(np.sum(np.square(dist), axis=2)) + # result is sorted by rows + rows, cols = linear_sum_assignment(cost_mat) + + cost_scipy = cost_mat[rows, cols].sum() + + assert np.isclose(cost_scipy, cost_our) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_equal_size(pre_match, use_scipy): +def test_cell_matches_equal_size(pre_match): a = as_cell([10, 20, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 0], [1, 1], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([5, 15, 25, 35]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert not b_ assert [[0, 1], [1, 0], [2, 2], [3, 3]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_other(pre_match, use_scipy): +def test_cell_matches_larger_other(pre_match): a = as_cell([1, 12, 100, 80]) b = as_cell([5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [2] assert [[0, 0], [1, 1], [2, 4], [3, 3]] == ab a = as_cell([20, 10, 30, 40]) b = as_cell([11, 22, 39, 42, 41]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [3] assert [[0, 1], [1, 0], [2, 2], [3, 4]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_larger_cells(pre_match, use_scipy): +def test_cell_matches_larger_cells(pre_match): a = as_cell([5, 15, 25, 35, 100]) b = as_cell([1, 12, 100, 80]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert a_ == [2] assert not b_ assert [[0, 0], [1, 1], [3, 3], [4, 2]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_cell_matches_threshold(pre_match, use_scipy): +def test_cell_matches_threshold(pre_match): a = as_cell([10, 12, 100, 80]) b = as_cell([0, 5, 15, 25, 35, 100]) - a_, ab, b_ = match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + a_, ab, b_ = match_cells(a, b, pre_match=pre_match) assert not a_ assert b_ == [0, 3] assert [[0, 1], [1, 2], [2, 5], [3, 4]] == ab @@ -125,16 +162,14 @@ def test_cell_matches_threshold(pre_match, use_scipy): b, threshold=math.sqrt(3) * 11, pre_match=pre_match, - use_scipy=use_scipy, ) assert a_ == [3] assert b_ == [0, 3, 4] assert [[0, 1], [1, 2], [2, 5]] == ab -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): +def 
test_global_optimum_with_threshold_original_pr(pre_match): cells1 = [ Cell((0, 0, 0), Cell.UNKNOWN), Cell((12, 0, 0), Cell.UNKNOWN), @@ -151,7 +186,6 @@ def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): cells2, threshold=np.inf, pre_match=pre_match, - use_scipy=use_scipy, ) assert not missing_c1 assert not missing_c2 @@ -160,115 +194,92 @@ def test_global_optimum_with_threshold_original_pr(pre_match, use_scipy): # with threshold, the previous pairing should not be considered good. # Instead, only (12, 10) is a good match. So while total cost is 24, # we only care about the cost of 2 during the matching algorithm - # BUT, only when not using scipy. With scipy it doesn't account for - # threshold during the matching, so it'll do it after missing_c1, good_matches, missing_c2 = match_cells( - cells1, cells2, threshold=5, pre_match=pre_match, use_scipy=use_scipy + cells1, cells2, threshold=5, pre_match=pre_match ) - # before we added the threshold to match_points, the following applies to - # both scipy and our own. After the fix, this is only True for scipy - if use_scipy: - assert missing_c1 == [0, 1] - assert missing_c2 == [0, 1] - assert not good_matches - else: - # with threshold in match_points, this is true - as wanted - assert missing_c1 == [ - 0, - ] - assert missing_c2 == [ - 1, - ] - assert good_matches == [[1, 0]] - - -@pytest.mark.parametrize("use_scipy", [True, False]) + assert missing_c1 == [ + 0, + ] + assert missing_c2 == [ + 1, + ] + assert good_matches == [[1, 0]] + + @pytest.mark.parametrize("pre_match", [True, False]) -def test_rows_greater_than_cols(pre_match, use_scipy): +def test_rows_greater_than_cols(pre_match): with pytest.raises(ValueError): match_points( np.zeros((5, 3)), np.zeros((4, 3)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_unequal_inputs_shape(pre_match, use_scipy): +def test_unequal_inputs_shape(pre_match): with pytest.raises(ValueError): match_points( np.zeros((5, 3)), np.zeros((5, 2)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_bad_input_shape(pre_match, use_scipy): +def test_bad_input_shape(pre_match): # has to be 2 dims with pytest.raises(ValueError): - match_points( - np.zeros(5), np.zeros(5), pre_match=pre_match, use_scipy=use_scipy - ) + match_points(np.zeros(5), np.zeros(5), pre_match=pre_match) with pytest.raises(ValueError): match_points( np.zeros((5, 4, 6)), np.zeros((5, 4, 6)), pre_match=pre_match, - use_scipy=use_scipy, ) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_progress_already_running(pre_match, use_scipy): +def test_progress_already_running(pre_match): a = as_cell([10, 12]) b = as_cell([10, 12]) cell_utils.__progress_update.updater = 1 try: with pytest.raises(TypeError): - match_cells(a, b, pre_match=pre_match, use_scipy=use_scipy) + match_cells(a, b, pre_match=pre_match) finally: cell_utils.__progress_update.updater = None -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_distance_too_large(pre_match, use_scipy): +def test_distance_too_large(pre_match): a = np.array([[1, 2, 3]]).T b = np.array([[1, 2, np.inf]]).T with pytest.raises(ValueError): - match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + match_points(a, b, 
pre_match=pre_match) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_identical_points(pre_match, use_scipy): +def test_contains_identical_points(pre_match): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[5, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [1, 0, 2]) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_only_identical_points(pre_match, use_scipy): +def test_contains_only_identical_points(pre_match): a = np.array([[1, 2, 3]]).T b = np.array([[2, 3, 5, 1]]).T - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [3, 0, 1]) -@pytest.mark.parametrize("use_scipy", [True, False]) @pytest.mark.parametrize("pre_match", [True, False]) -def test_contains_no_identical_points(pre_match, use_scipy): +def test_contains_no_identical_points(pre_match): a = np.array([[1, 10], [5, 7], [22, 12]]) b = np.array([[6, 7], [7, 1], [21, 10]]) - matching = match_points(a, b, pre_match=pre_match, use_scipy=use_scipy) + matching = match_points(a, b, pre_match=pre_match) assert np.array_equal(matching, [1, 0, 2]) From bb0ed47ad7e4cb89f3da156b7dc797537193baac Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Wed, 5 Jun 2024 21:49:49 -0400 Subject: [PATCH 12/13] Run cov without numba jit and support GH actions pooch cache. --- .github/workflows/test_and_deploy.yml | 42 +++++++++++++++++++++++++++ pyproject.toml | 1 + 2 files changed, 43 insertions(+) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index a6af755..e0301d7 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -26,6 +26,9 @@ jobs: needs: [linting, manifest] name: ${{ matrix.os }} py${{ matrix.python-version }} runs-on: ${{ matrix.os }} + env: + # used on unix by pooch for cache dir + XDG_CACHE_HOME: "~/.pooch_cache" strategy: matrix: # Run all supported Python versions on linux @@ -41,6 +44,13 @@ jobs: python-version: "3.11" steps: + - name: Cache pooch data + uses: actions/cache@v4 + with: + path: "~/.pooch_cache" + # hash on conftest in case url changes + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/conftest.py') }} + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: @@ -55,6 +65,38 @@ jobs: python-version: ${{ matrix.python-version }} secret-codecov-token: ${{ secrets.CODECOV_TOKEN }} + test_numba_disabled: + needs: [ linting, manifest ] + name: Run tests with numba disabled + runs-on: ubuntu-latest + env: + NUMBA_DISABLE_JIT: "1" + # used on unix by pooch for cache dir + XDG_CACHE_HOME: "~/.pooch_cache" + + steps: + - name: Cache pooch data + uses: actions/cache@v4 + with: + path: "~/.pooch_cache" + # hash on conftest in case url changes + key: ${{ runner.os }}-3.11-${{ hashFiles('**/conftest.py') }} + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.11" + # these libraries enable testing on Qt on linux + - uses: pyvista/setup-headless-display-action@v2 + with: + qt: true + # Run test suite with numba disabled + - uses: neuroinformatics-unit/actions/test@v2 + with: + python-version: "3.11" + secret-codecov-token: ${{ secrets.CODECOV_TOKEN }} + codecov-flags: "numba" 
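Note: this separate job exists because coverage tools cannot trace the machine
code that numba JIT-compiles. Setting `NUMBA_DISABLE_JIT=1` makes `@njit`
functions execute as plain Python, so their bodies become measurable. A
minimal sketch of the mechanism (illustration only, not project code)::

    import os

    # must be set before numba compiles anything
    os.environ["NUMBA_DISABLE_JIT"] = "1"

    from numba import njit

    @njit
    def total(n):
        s = 0
        for i in range(n):
            s += i
        return s

    # with the JIT disabled this runs as ordinary Python bytecode, so
    # line-by-line coverage of the loop is recorded
    assert total(5) == 10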
+ build_sdist_wheels: name: Build source distribution needs: [test] diff --git a/pyproject.toml b/pyproject.toml index e7a83cd..bd89129 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dev = [ "setuptools_scm", "superqt", "tox", + "pooch", ] From acf0d6387a1dcb466963296053c79d5949aeb8bb Mon Sep 17 00:00:00 2001 From: Matt Einhorn Date: Thu, 6 Jun 2024 00:39:07 -0400 Subject: [PATCH 13/13] Use proper env variable for pooch cache. --- .github/workflows/test_and_deploy.yml | 9 +++++---- pyproject.toml | 1 + tests/tests/conftest.py | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index e0301d7..7f94ffa 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -27,8 +27,9 @@ jobs: name: ${{ matrix.os }} py${{ matrix.python-version }} runs-on: ${{ matrix.os }} env: - # used on unix by pooch for cache dir - XDG_CACHE_HOME: "~/.pooch_cache" + # pooch cache dir + BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache" + strategy: matrix: # Run all supported Python versions on linux @@ -71,8 +72,8 @@ jobs: runs-on: ubuntu-latest env: NUMBA_DISABLE_JIT: "1" - # used on unix by pooch for cache dir - XDG_CACHE_HOME: "~/.pooch_cache" + # pooch cache dir + BRAINGLOBE_TEST_DATA_DIR: "~/.pooch_cache" steps: - name: Cache pooch data diff --git a/pyproject.toml b/pyproject.toml index bd89129..78f54f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,4 +145,5 @@ passenv = DISPLAY XAUTHORITY PYVISTA_OFF_SCREEN + BRAINGLOBE_TEST_DATA_DIR """ diff --git a/tests/tests/conftest.py b/tests/tests/conftest.py index 2ac4e08..f0a78cc 100644 --- a/tests/tests/conftest.py +++ b/tests/tests/conftest.py @@ -26,5 +26,6 @@ def test_data_registry(): "cellfinder/cells-z-1000-1050.xml": None, "cellfinder/other-cells-z-1000-1050.xml": None, }, + env="BRAINGLOBE_TEST_DATA_DIR", ) return registry
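For reference, a sketch of how the `env` argument added above affects the
cache location: pooch prefers the named environment variable over `path` when
the variable is set (the directory below is hypothetical, mirroring what the
CI jobs set)::

    import os
    import pooch

    # hypothetical override, standing in for the value set in CI
    os.environ["BRAINGLOBE_TEST_DATA_DIR"] = "/tmp/brainglobe_test_data"

    registry = pooch.create(
        path=pooch.os_cache("brainglobe_test_data"),  # fallback when unset
        base_url="https://gin.g-node.org/BrainGlobe/test-data/raw/master/",
        registry={"cellfinder/cells-z-1000-1050.xml": None},
        env="BRAINGLOBE_TEST_DATA_DIR",
    )

    # downloads now land in /tmp/brainglobe_test_data, not the OS cache
    path = registry.fetch("cellfinder/cells-z-1000-1050.xml")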