From d6b479fcdf6cebf7c369bd6e1d718d8635f34007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 14:59:11 -0500 Subject: [PATCH 01/24] deps: add duckdb and nix stuff --- flake.lock | 81 +++++++++++++++++++++++++++++++++++++++ flake.nix | 102 +++++++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 3 +- 3 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 flake.lock create mode 100644 flake.nix diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..d8f4418 --- /dev/null +++ b/flake.lock @@ -0,0 +1,81 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": [ + "systems" + ] + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1739357830, + "narHash": "sha256-9xim3nJJUFbVbJCz48UP4fGRStVW5nv4VdbimbKxJ3I=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "0ff09db9d034a04acd4e8908820ba0b410d7a33a", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-24.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_master": { + "locked": { + "lastModified": 1739550317, + "narHash": "sha256-Tdrwxe81xIMTAzX2lQaoU+Sz3JDTHtUEzLFYBv84IB0=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "5866d80b157e31becb14f9084ba560b818f7e989", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "master", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "nixpkgs_master": "nixpkgs_master", + "systems": "systems" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..2401890 --- /dev/null +++ b/flake.nix @@ -0,0 +1,102 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11"; + nixpkgs_master.url = "github:NixOS/nixpkgs/master"; + systems.url = "github:nix-systems/default"; + flake-utils.url = "github:numtide/flake-utils"; + flake-utils.inputs.systems.follows = "systems"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + ... + }@inputs: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { + inherit system; + config.allowUnfree = true; + config.cudaSupport = true; + }; + + mpkgs = import inputs.nixpkgs_master { + inherit system; + config.allowUnfree = true; + config.cudaSupport = true; + }; + + libList = + [ + # Add needed packages here + pkgs.stdenv.cc.cc + pkgs.libGL + pkgs.glib + pkgs.zlib + ] + ++ pkgs.lib.optionals pkgs.stdenv.isLinux ( + with pkgs; + [ + # cudatoolkit + + # This is required for most app that uses graphics api + # linuxPackages.nvidia_x11 + ] + ); + in + with pkgs; + { + devShells = { + default = + let + python_with_pkgs = pkgs.python311.withPackages (pp: [ + # Add python pkgs here that you need from nix repos + ruff + ]); + in + mkShell { + NIX_LD = runCommand "ld.so" { } '' + ln -s "$(cat '${pkgs.stdenv.cc}/nix-support/dynamic-linker')" $out + ''; + NIX_LD_LIBRARY_PATH = lib.makeLibraryPath libList; + packages = [ + python_with_pkgs + python311Packages.venvShellHook + # We # We now recommend to use uv for package management inside nix env + mpkgs.uv + + # duckdb + ] ++ libList; + venvDir = "./.venv"; + postVenvCreation = '' + unset SOURCE_DATE_EPOCH + ''; + postShellHook = '' + unset SOURCE_DATE_EPOCH + ''; + shellHook = '' + export LD_LIBRARY_PATH=$NIX_LD_LIBRARY_PATH:$LD_LIBRARY_PATH + export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring + runHook venvShellHook + uv pip install -e . + export PYTHONPATH=${python_with_pkgs}/${python_with_pkgs.sitePackages}:$PYTHONPATH + ''; + }; + }; + } + ); +} +# Things one might need for debugging or adding compatibility +# export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit} +# export LD_LIBRARY_PATH=${pkgs.cudaPackages.cuda_nvrtc}/lib +# export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib" +# export EXTRA_CCFLAGS="-I/usr/include" + +# Data syncthing commands +# syncthing cli show system | jq .myID +# syncthing cli config devices add --device-id $DEVICE_ID_B +# syncthing cli config folders $FOLDER_ID devices add --device-id $DEVICE_ID_B +# syncthing cli config devices $DEVICE_ID_A auto-accept-folders set true diff --git a/pyproject.toml b/pyproject.toml index 92728a9..c846c09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ authors = [ dependencies = [ "pandas", "tqdm", - "statsmodels" + "statsmodels", + "duckdb>=1.2.0", ] [project.optional-dependencies] From 76e2d113198c600274e7aed287315e8c560bf945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 14:59:31 -0500 Subject: [PATCH 02/24] feat: add temporary dev file --- src/copairs/dev.py | 67 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 src/copairs/dev.py diff --git a/src/copairs/dev.py b/src/copairs/dev.py new file mode 100644 index 0000000..553de23 --- /dev/null +++ b/src/copairs/dev.py @@ -0,0 +1,67 @@ +"""Helper functions for testing.""" + +import pandas as pd +import duckdb + +Seed = 0 +# Cols are c, p and w +sameby = ["c"] +diffby = [] + +def simulate_plates(n_compounds, n_replicates, plate_size): + """Round robin creation of platemaps.""" + total = n_compounds * n_replicates + + compounds = [] + plates = [] + wells = [] + for i in range(total): + compound_id = i % n_compounds + well_id = i % plate_size + plate_id = i // plate_size + compounds.append(f"c{compound_id}") + plates.append(f"p{plate_id}") + wells.append(f"w{well_id}") + +def find_pairs(dframe, sameby, diffby): + # Assumes sameby or diffby is not empty + df = dframe.reset_index() + with duckdb.connect("main"): + pos_suffix = [f"AND A.{x} = B.{x}" for x in sameby[1:]] + neg_suffix = [f"AND NOT A.{x} = B.{x}" for x in diffby] + string = ( + # f"SELECT list(index),list(index_1),{','.join('first(' + x +')' for x in sameby)} FROM (" + f"SELECT {','.join(['CAST(A.' + x +') AS ' for x in sameby])},A.index,B.index " + 'FROM df A ' + 'JOIN df B ' + f"ON A.{sameby[0]} = B.{sameby[0]}" + f" {' '.join((*pos_suffix, *neg_suffix))} " + # f" {' '.join((*pos_suffix, *neg_suffix))})" + # f"GROUP BY {','.join(sameby)}" + ) + tmp = duckdb.sql(string) + + dframe = pd.DataFrame({"c": compounds, "p": plates, "w": wells}) + return dframe + +# Gen data +%timeit dframe = simulate_plates(n_compounds=15000, n_replicates=20, plate_size=384) + +# Load matcher +%timeit matcher = Matcher(dframe, dframe.columns, seed=SEED) + +# Evaluate +# pairs_dict = matcher.get_all_pairs(sameby, diffby) +# %timeit pairs_dict2 = find_pairs(dframe,sameby, diffby) + + +# Compounds = 15000 +#%timeit pairs_dict = matcher.get_all_pairs(sameby, diffby) +# 428 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) +# %timeit pairs_dict2 = find_pairs(dframe,sameby, diffby) +# 30.1 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) +# Compounds = 150000 +# %timeit pairs_dict = matcher.get_all_pairs(sameby, diffby) +# 4.59 s ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) +# %timeit pairs_dict2 = find_pairs(dframe,sameby, diffby) +# 151 ms ± 612 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) From c8380d88e19331c16e728877e413e3b466fef920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 15:30:39 -0500 Subject: [PATCH 03/24] dev: clean up and add test --- src/copairs/dev.py | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/src/copairs/dev.py b/src/copairs/dev.py index 553de23..2bff799 100644 --- a/src/copairs/dev.py +++ b/src/copairs/dev.py @@ -7,6 +7,9 @@ # Cols are c, p and w sameby = ["c"] diffby = [] +n_compounds = 15000 +n_replicates = 20 +plate_size = 384 def simulate_plates(n_compounds, n_replicates, plate_size): """Round robin creation of platemaps.""" @@ -22,7 +25,9 @@ def simulate_plates(n_compounds, n_replicates, plate_size): compounds.append(f"c{compound_id}") plates.append(f"p{plate_id}") wells.append(f"w{well_id}") - + + dframe = pd.DataFrame({"c": compounds, "p": plates, "w": wells}) + return dframe def find_pairs(dframe, sameby, diffby): # Assumes sameby or diffby is not empty df = dframe.reset_index() @@ -30,29 +35,33 @@ def find_pairs(dframe, sameby, diffby): pos_suffix = [f"AND A.{x} = B.{x}" for x in sameby[1:]] neg_suffix = [f"AND NOT A.{x} = B.{x}" for x in diffby] string = ( - # f"SELECT list(index),list(index_1),{','.join('first(' + x +')' for x in sameby)} FROM (" - f"SELECT {','.join(['CAST(A.' + x +') AS ' for x in sameby])},A.index,B.index " + f"SELECT {','.join(['A.' + x for x in sameby])},A.index,B.index " 'FROM df A ' 'JOIN df B ' f"ON A.{sameby[0]} = B.{sameby[0]}" f" {' '.join((*pos_suffix, *neg_suffix))} " - # f" {' '.join((*pos_suffix, *neg_suffix))})" - # f"GROUP BY {','.join(sameby)}" ) + # tmp = duckdb.sql(f"SELECT * WHERE(c = c0) FROM ({string})") tmp = duckdb.sql(string) + return tmp - dframe = pd.DataFrame({"c": compounds, "p": plates, "w": wells}) - return dframe + return None # Gen data -%timeit dframe = simulate_plates(n_compounds=15000, n_replicates=20, plate_size=384) +dframe = simulate_plates(n_compounds, n_replicates, plate_size) # Load matcher -%timeit matcher = Matcher(dframe, dframe.columns, seed=SEED) +matcher = Matcher(dframe, dframe.columns, seed=SEED) # Evaluate -# pairs_dict = matcher.get_all_pairs(sameby, diffby) -# %timeit pairs_dict2 = find_pairs(dframe,sameby, diffby) +# %timeit pairs_dict = matcher.get_all_pairs(sameby, diffby) +duckdb_results = find_pairs(dframe,sameby, diffby) +"""Assert the pairs are valid.""" +for _, id1, id2 in duckdb_results.fetchall(): + for col in sameby: + assert dframe.loc[id1, col] == dframe.loc[id2, col] + for col in diffby: + assert dframe.loc[id1, col] != dframe.loc[id2, col] # Compounds = 15000 From f68c34c21b0fa49f334d97874dd38a69697ea65f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 16:27:57 -0500 Subject: [PATCH 04/24] index on duckdb: c8380d8 dev: clean up and add test From f9af6fcc5670018deac9a7100905ab4bb4f76e23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 16:45:14 -0500 Subject: [PATCH 05/24] index on (no branch): 8b492ec On duckdb: c8380d8 dev: clean up and add test From 517045b46060cc09303f6efc947c9f43a7013cca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:09:24 -0500 Subject: [PATCH 06/24] deps: add ipdb and jupyter as dev deps. - ipdb: better debugger - jupyter: Provides autoreload function --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fc1738d..0037a7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,11 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["ruff"] +dev = [ + "ipdb>=0.13.13", + "jupyter>=1.1.1", + "ruff", +] plot = ["plotly"] test = ["scikit-learn", "pytest"] demo = ["notebook", "matplotlib"] From b7e296e04dd3cd44402776244e00ac2e59eea681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:10:14 -0500 Subject: [PATCH 07/24] fix(find_pairs): The new function passes tests and goes brrr --- src/copairs/matching.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index 2b28633..e0010f6 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -464,25 +464,21 @@ def filter_fn(x): return set(x) not in pairs return {None: list(filter(filter_fn, all_pairs))} - -def find_pairs(dframe, sameby:list[str], diffby:list[str])->np.ndarray: - """ - Find the indices of pairs in which share the same value in `sameby` columns but not on `diffby` columns. - This assumes that sameby is not empty - - """ + + +def find_pairs(dframe, sameby: list[str], diffby: list[str], inside=True) -> np.ndarray: + """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns.""" df = dframe.reset_index() with duckdb.connect("main"): - pos_suffix = [f"AND A.{x} = B.{x}" for x in sameby[1:]] - neg_suffix = [f"AND NOT A.{x} = B.{x}" for x in diffby] + pos_suffix = [f"A.{x} = B.{x}" for x in sameby] + neg_suffix = [f"NOT A.{x} = B.{x}" for x in diffby] string = ( - # f"SELECT A.index,B.index,{','.join(['A.' + x for x in sameby])} " - f"SELECT A.index,B.index " - 'FROM df A ' - 'JOIN df B ' - f"ON A.{sameby[0]} = B.{sameby[0]}" - f" {' '.join((*pos_suffix, *neg_suffix))} " + f"SELECT A.index,B.index" + " FROM df A" + " JOIN df B" + " ON A.index < B.index" # Ensures only one of (a,b)/(b,a) and no (a,a) + f" AND {' AND '.join((*pos_suffix, *neg_suffix))}" ) index_d = duckdb.sql(string).fetchnumpy() - return np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T + return np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T From 71420cb579fa2beef39c1eb60a222f3ba9b9b9f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:21:55 -0500 Subject: [PATCH 08/24] clean: remove dev.py --- src/copairs/dev.py | 62 ---------------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 src/copairs/dev.py diff --git a/src/copairs/dev.py b/src/copairs/dev.py deleted file mode 100644 index 040adfa..0000000 --- a/src/copairs/dev.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Helper functions for testing.""" - -import pandas as pd - -from copairs.matching import Matcher, find_pairs - -SEED = 42 - -# Cols are c, p and w -sameby = ["c", "w"] -diffby = [] -n_compounds = 150000 -n_replicates = 20 -plate_size = 384 - -def simulate_plates(n_compounds, n_replicates, plate_size): - """Round robin creation of platemaps.""" - total = n_compounds * n_replicates - - compounds = [] - plates = [] - wells = [] - for i in range(total): - compound_id = i % n_compounds - well_id = i % plate_size - plate_id = i // plate_size - compounds.append(f"c{compound_id}") - plates.append(f"p{plate_id}") - wells.append(f"w{well_id}") - - dframe = pd.DataFrame({"c": compounds, "p": plates, "w": wells}) - return dframe - -# Gen data -dframe = simulate_plates(n_compounds, n_replicates, plate_size) - -# Load matcher -matcher = Matcher(dframe, dframe.columns, seed=SEED) - -""" -#Evaluate correctness -pairs = find_pairs(dframe,sameby, diffby) -for id1, id2, *_ in pairs.fetchall(): - for col in sameby: - assert dframe.loc[id1, col] == dframe.loc[id2, col] - for col in diffby: - assert dframe.loc[id1, col] != dframe.loc[id2, col] -""" - -""" -# Current -%timeit matcher.get_all_pairs(sameby, diffby) -1500: 862 ms ± 4.33 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -15000: 8.7 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) -150000: 1min 30s ± 312 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) - -# Duckdb -%timeit find_pairs(dframe,sameby, diffby) -1500: 13.8 ms ± 67.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) -15000: 27.8 ms ± 468 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) -150000: 147 ms ± 937 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) -""" From 23874c7eff2f072dff52cf0858a3d6b3e2835d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:26:41 -0500 Subject: [PATCH 09/24] clean: remove commented code on average_precision --- src/copairs/map/average_precision.py | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py index 8ddec24..39bb054 100644 --- a/src/copairs/map/average_precision.py +++ b/src/copairs/map/average_precision.py @@ -8,7 +8,7 @@ import pandas as pd from copairs import compute -from copairs.matching import find_pairs, UnpairedException +from copairs.matching import UnpairedException, find_pairs from .filter import evaluate_and_filter, flatten_str_list, validate_pipeline_input @@ -53,12 +53,10 @@ def build_rank_lists( Array of counts indicating how many times each profile index appears in the rank lists. """ # Combine relevance labels: 1 for positive pairs, 0 for negative pairs - labels = np.concatenate( - [ - np.ones(pos_pairs.size, dtype=np.uint32), - np.zeros(neg_pairs.size, dtype=np.uint32), - ] - ) + labels = np.concatenate([ + np.ones(pos_pairs.size, dtype=np.uint32), + np.zeros(neg_pairs.size, dtype=np.uint32), + ]) # Flatten positive and negative pair indices for ranking ix = np.concatenate([pos_pairs.ravel(), neg_pairs.ravel()]) @@ -179,7 +177,6 @@ def average_precision( meta = meta.reset_index(drop=True).copy() logger.info("Indexing metadata...") - # matcher = (meta, columns, seed=0) # Identify positive pairs based on `pos_sameby` and `pos_diffby` logger.info("Finding positive pairs...") @@ -188,13 +185,6 @@ def average_precision( if len(pos_pairs) == 0: raise UnpairedException("Unable to find positive pairs.") - # # Convert positive pairs to a NumPy array for efficient computation - # pos_pairs = np.fromiter( - # itertools.chain.from_iterable(pos_pairs.values()), - # dtype=np.dtype((np.uint32, 2)), - # count=pos_total, - # ) - # Identify negative pairs based on `neg_sameby` and `neg_diffby` logger.info("Finding negative pairs...") neg_pairs = find_pairs(meta, sameby=neg_sameby, diffby=neg_diffby) @@ -202,13 +192,6 @@ def average_precision( if len(neg_pairs) == 0: raise UnpairedException("Unable to find negative pairs.") - # Convert negative pairs to a NumPy array for efficient computation - # neg_pairs = np.fromiter( - # itertools.chain.from_iterable(neg_pairs.values()), - # dtype=np.dtype((np.uint32, 2)), - # count=neg_total, - # ) - # Compute similarities for positive pairs logger.info("Computing positive similarities...") pos_sims = distance_fn(feats, pos_pairs, batch_size) From 6114b17bae922730ee411eca4831681a10133790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:29:03 -0500 Subject: [PATCH 10/24] style: ruff format --- src/copairs/map/average_precision.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py index 39bb054..36c5323 100644 --- a/src/copairs/map/average_precision.py +++ b/src/copairs/map/average_precision.py @@ -53,10 +53,12 @@ def build_rank_lists( Array of counts indicating how many times each profile index appears in the rank lists. """ # Combine relevance labels: 1 for positive pairs, 0 for negative pairs - labels = np.concatenate([ - np.ones(pos_pairs.size, dtype=np.uint32), - np.zeros(neg_pairs.size, dtype=np.uint32), - ]) + labels = np.concatenate( + [ + np.ones(pos_pairs.size, dtype=np.uint32), + np.zeros(neg_pairs.size, dtype=np.uint32), + ] + ) # Flatten positive and negative pair indices for ranking ix = np.concatenate([pos_pairs.ravel(), neg_pairs.ravel()]) From 1dcb0dc0d4c33f5414db7cc09f1dc3f3a1bede35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:32:08 -0500 Subject: [PATCH 11/24] fix(py38): remove typing that breaks CI checks for py38 --- src/copairs/matching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index e0010f6..c25036c 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -466,7 +466,7 @@ def filter_fn(x): return {None: list(filter(filter_fn, all_pairs))} -def find_pairs(dframe, sameby: list[str], diffby: list[str], inside=True) -> np.ndarray: +def find_pairs(dframe, sameby, diffby, inside=True) -> np.ndarray: """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns.""" df = dframe.reset_index() with duckdb.connect("main"): From b7a4673c9c3ba5d4d443c9e245d1af33d7d584f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 18:36:43 -0500 Subject: [PATCH 12/24] clean: delete sneaky leftover comments --- src/copairs/map/average_precision.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py index 36c5323..f1782c5 100644 --- a/src/copairs/map/average_precision.py +++ b/src/copairs/map/average_precision.py @@ -183,14 +183,12 @@ def average_precision( # Identify positive pairs based on `pos_sameby` and `pos_diffby` logger.info("Finding positive pairs...") pos_pairs = find_pairs(meta, sameby=pos_sameby, diffby=pos_diffby) - # pos_total = sum(len(p) for p in pos_pairs.values()) if len(pos_pairs) == 0: raise UnpairedException("Unable to find positive pairs.") # Identify negative pairs based on `neg_sameby` and `neg_diffby` logger.info("Finding negative pairs...") neg_pairs = find_pairs(meta, sameby=neg_sameby, diffby=neg_diffby) - # neg_total = sum(len(p) for p in neg_pairs.values()) if len(neg_pairs) == 0: raise UnpairedException("Unable to find negative pairs.") From c237287f9386a0b47aefb98090487f9fb1f8ae23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 20:14:13 -0500 Subject: [PATCH 13/24] fix: cover arg errors inside find_pairs --- src/copairs/matching.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index c25036c..e360875 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -468,6 +468,18 @@ def filter_fn(x): def find_pairs(dframe, sameby, diffby, inside=True) -> np.ndarray: """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns.""" + + if isinstance(sameby, str): + sameby = (sameby,) + if isinstance(diffby, str): + sameby = (diffby,) + + if not (len(sameby) or len(diffby)): + raise ValueError("at least one should be provided") + + if len(set(sameby).intersection(diffby)): + raise ValueError("sameby and diffby must be disjoint lists") + df = dframe.reset_index() with duckdb.connect("main"): pos_suffix = [f"A.{x} = B.{x}" for x in sameby] From 0812f1784b6580caff1f55f0e155e1843a16a6c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 20:14:31 -0500 Subject: [PATCH 14/24] tests: replace most Matcher tests with find_pairs --- tests/test_matching.py | 44 ++++++++++++++++++------------------------ 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 91c494f..e2f4a99 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -7,6 +7,7 @@ import pytest from copairs import Matcher +from copairs.matching import find_pairs from tests.helpers import create_dframe, simulate_plates, simulate_random_dframe SEED = 0 @@ -61,11 +62,12 @@ def get_naive_pairs(dframe: pd.DataFrame, sameby, diffby): return pairs -def check_naive(dframe, matcher: Matcher, sameby, diffby): +def check_naive(dframe, sameby, diffby): """Check Matcher and naive generate same pairs.""" gt_pairs = get_naive_pairs(dframe, sameby, diffby) - vals = matcher.get_all_pairs(sameby, diffby) - vals = sum(vals.values(), []) + # vals = matcher.get_all_pairs(sameby, diffby) + vals = find_pairs(dframe, sameby, diffby) + # vals = sum(vals.values(), []) vals = pd.DataFrame(vals, columns=["index_x", "index_y"]) vals = vals.sort_values(["index_x", "index_y"]).reset_index(drop=True) vals = set(vals.apply(frozenset, axis=1)) @@ -76,8 +78,7 @@ def check_naive(dframe, matcher: Matcher, sameby, diffby): def check_simulated_data(length, vocab_size, sameby, diffby, rng): """Test sample of valid pairs from a simulated dataset.""" dframe = simulate_random_dframe(length, vocab_size, sameby, diffby, rng) - matcher = Matcher(dframe, dframe.columns, seed=SEED) - check_naive(dframe, matcher, sameby, diffby) + check_naive(dframe, sameby, diffby) def test_stress_simulated_data(): @@ -101,51 +102,44 @@ def test_stress_simulated_data(): def test_empty_sameby(): """Test query without sameby.""" dframe = create_dframe(3, 10) - matcher = Matcher(dframe, dframe.columns, seed=SEED) - check_naive(dframe, matcher, sameby=[], diffby=["w", "c"]) - check_naive(dframe, matcher, sameby=[], diffby=["w"]) + check_naive(dframe, sameby=[], diffby=["w", "c"]) + check_naive(dframe, sameby=[], diffby=["w"]) def test_empty_diffby(): """Test query without diffby.""" dframe = create_dframe(3, 10) - matcher = Matcher(dframe, dframe.columns, seed=SEED) - matcher.get_all_pairs(["c"], []) - check_naive(dframe, matcher, sameby=["c"], diffby=[]) - check_naive(dframe, matcher, sameby=["w", "c"], diffby=[]) + check_naive(dframe, sameby=["c"], diffby=[]) + check_naive(dframe, sameby=["w", "c"], diffby=[]) def test_raise_distjoint(): """Test check for disjoint sameby and diffby.""" dframe = create_dframe(3, 10) - matcher = Matcher(dframe, dframe.columns, seed=SEED) with pytest.raises(ValueError, match="must be disjoint lists"): - matcher.get_all_pairs("c", ["w", "c"]) + find_pairs(dframe, "c", ["w", "c"]) def test_raise_no_params(): """Test check for at least one of sameby and diffby.""" dframe = create_dframe(3, 10) - matcher = Matcher(dframe, dframe.columns, seed=SEED) with pytest.raises(ValueError, match="at least one should be provided"): - matcher.get_all_pairs([], []) + find_pairs(dframe, [], []) -def assert_sameby_diffby(dframe: pd.DataFrame, pairs_dict: dict, sameby, diffby): +def assert_sameby_diffby(dframe: pd.DataFrame, pairs: dict, sameby, diffby): """Assert the pairs are valid.""" - for _, pairs in pairs_dict.items(): - for id1, id2 in pairs: - for col in sameby: - assert dframe.loc[id1, col] == dframe.loc[id2, col] - for col in diffby: - assert dframe.loc[id1, col] != dframe.loc[id2, col] + for id1, id2 in pairs: + for col in sameby: + assert dframe.loc[id1, col] == dframe.loc[id2, col] + for col in diffby: + assert dframe.loc[id1, col] != dframe.loc[id2, col] def test_simulate_plates_mult_sameby_large(): """Test matcher successfully complete analysis of a large dataset.""" dframe = simulate_plates(n_compounds=15000, n_replicates=20, plate_size=384) - matcher = Matcher(dframe, dframe.columns, seed=SEED) sameby = ["c", "w"] diffby = ["p"] - pairs_dict = matcher.get_all_pairs(sameby, diffby) + pairs_dict = find_pairs(dframe, sameby, diffby) assert_sameby_diffby(dframe, pairs_dict, sameby, diffby) From c9365f73e990a0966a9adb3e6654edd5e6e509a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 21:05:10 -0500 Subject: [PATCH 15/24] feat(find_pairs): support set complement --- src/copairs/matching.py | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index e360875..dfd6d29 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -466,31 +466,41 @@ def filter_fn(x): return {None: list(filter(filter_fn, all_pairs))} -def find_pairs(dframe, sameby, diffby, inside=True) -> np.ndarray: - """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns.""" - - if isinstance(sameby, str): - sameby = (sameby,) - if isinstance(diffby, str): - sameby = (diffby,) +def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: + """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns. + + `rev` reverses same and diff, which means that we get the complement + """ + sameby, diffby = _validate(sameby, diffby) - if not (len(sameby) or len(diffby)): - raise ValueError("at least one should be provided") - if len(set(sameby).intersection(diffby)): - raise ValueError("sameby and diffby must be disjoint lists") - + raise ValueError("sameby and diffby must be disjoint lists") + df = dframe.reset_index() with duckdb.connect("main"): - pos_suffix = [f"A.{x} = B.{x}" for x in sameby] - neg_suffix = [f"NOT A.{x} = B.{x}" for x in diffby] + group_1, group_2 = [ + [f"{('', 'NOT')[i - rev]} A.{x} = B.{x}" for x in y] + for i, y in enumerate((sameby, diffby)) + ] string = ( f"SELECT A.index,B.index" " FROM df A" " JOIN df B" " ON A.index < B.index" # Ensures only one of (a,b)/(b,a) and no (a,a) - f" AND {' AND '.join((*pos_suffix, *neg_suffix))}" + f" AND {' AND '.join((*group_1, *group_2))}" ) index_d = duckdb.sql(string).fetchnumpy() return np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T + + +def _validate(sameby, diffby): + if isinstance(sameby, str): + sameby = (sameby,) + if isinstance(diffby, str): + sameby = (diffby,) + + if not (len(sameby) or len(diffby)): + raise ValueError("at least one should be provided") + + return sameby, diffby From 011baa8fedb8fd667fb2ac8d1c14a8086d25f35b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 21:05:31 -0500 Subject: [PATCH 16/24] tests: adjust test to replace Matcher with find_pairs --- tests/test_matching.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index e2f4a99..4484c14 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -15,9 +15,11 @@ def run_stress_sample_null(dframe, num_pairs): """Assert every generated null pair does not match any column.""" - matcher = Matcher(dframe, dframe.columns, seed=SEED) + null_pair = find_pairs(dframe, dframe.columns, [], rev=True) + randint = np.random.randint(len(null_pair)) + sample = null_pair[randint] for _ in range(num_pairs): - id1, id2 = matcher.sample_null_pair(dframe.columns) + id1, id2 = sample row1 = dframe.loc[id1] row2 = dframe.loc[id2] assert (row1 != row2).all() From b7004637de686d53eb005a91af3f25e771ffa102 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 21:09:57 -0500 Subject: [PATCH 17/24] fix(test): sample all the values --- tests/test_matching.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 4484c14..87f80ce 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -16,10 +16,9 @@ def run_stress_sample_null(dframe, num_pairs): """Assert every generated null pair does not match any column.""" null_pair = find_pairs(dframe, dframe.columns, [], rev=True) - randint = np.random.randint(len(null_pair)) - sample = null_pair[randint] - for _ in range(num_pairs): - id1, id2 = sample + randints = np.random.randint(len(null_pair), size=num_pairs) + for i in randints: + id1, id2 = null_pair[i] row1 = dframe.loc[id1] row2 = dframe.loc[id2] assert (row1 != row2).all() From b79f798db156a5a712d5bbd403282c01480ed9f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Fri, 14 Feb 2025 21:14:13 -0500 Subject: [PATCH 18/24] clean: (finally) remove matcher from test_matching --- tests/test_matching.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 87f80ce..62bb794 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -6,7 +6,6 @@ import pandas as pd import pytest -from copairs import Matcher from copairs.matching import find_pairs from tests.helpers import create_dframe, simulate_plates, simulate_random_dframe From b34158ae06a0355c89a715ffe8287e431e73cab6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Tue, 18 Feb 2025 20:10:24 -0500 Subject: [PATCH 19/24] change(matching): replace local database for in-memory --- src/copairs/matching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index dfd6d29..032f2bc 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -477,7 +477,7 @@ def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: raise ValueError("sameby and diffby must be disjoint lists") df = dframe.reset_index() - with duckdb.connect("main"): + with duckdb.connect(":memory:"): group_1, group_2 = [ [f"{('', 'NOT')[i - rev]} A.{x} = B.{x}" for x in y] for i, y in enumerate((sameby, diffby)) From 0800efac77ac25f009239029d051213674ee2e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Tue, 18 Feb 2025 20:10:44 -0500 Subject: [PATCH 20/24] deps: replace dev dep ipdb with trepan3k --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0037a7c..b47b672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,8 @@ convention = "numpy" [dependency-groups] dev = [ - "ipdb>=0.13.13", "jupyter>=1.1.1", + "trepan3k>=1.3.1", ] test = [ "scikit-learn>=1.3.2", From 305b86ab945764dcce8b4796da8753730d7241d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Wed, 19 Feb 2025 08:11:19 -0500 Subject: [PATCH 21/24] clean(nix); remove unused lines in flake --- flake.nix | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/flake.nix b/flake.nix index 2401890..dc66eca 100644 --- a/flake.nix +++ b/flake.nix @@ -40,10 +40,6 @@ ++ pkgs.lib.optionals pkgs.stdenv.isLinux ( with pkgs; [ - # cudatoolkit - - # This is required for most app that uses graphics api - # linuxPackages.nvidia_x11 ] ); in @@ -54,7 +50,6 @@ let python_with_pkgs = pkgs.python311.withPackages (pp: [ # Add python pkgs here that you need from nix repos - ruff ]); in mkShell { @@ -65,10 +60,8 @@ packages = [ python_with_pkgs python311Packages.venvShellHook - # We # We now recommend to use uv for package management inside nix env mpkgs.uv - # duckdb ] ++ libList; venvDir = "./.venv"; postVenvCreation = '' @@ -89,14 +82,3 @@ } ); } -# Things one might need for debugging or adding compatibility -# export CUDA_PATH=${pkgs.cudaPackages.cudatoolkit} -# export LD_LIBRARY_PATH=${pkgs.cudaPackages.cuda_nvrtc}/lib -# export EXTRA_LDFLAGS="-L/lib -L${pkgs.linuxPackages.nvidia_x11}/lib" -# export EXTRA_CCFLAGS="-I/usr/include" - -# Data syncthing commands -# syncthing cli show system | jq .myID -# syncthing cli config devices add --device-id $DEVICE_ID_B -# syncthing cli config folders $FOLDER_ID devices add --device-id $DEVICE_ID_B -# syncthing cli config devices $DEVICE_ID_A auto-accept-folders set true From c85e5b5e42f45d8f229e0e4c8f59f92e593e3ec6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Wed, 19 Feb 2025 08:22:20 -0500 Subject: [PATCH 22/24] feat: add timing decorator to matching, average_precision and p_val --- src/copairs/compute.py | 3 +++ src/copairs/map/average_precision.py | 12 ++++++------ src/copairs/matching.py | 3 +++ src/copairs/timing.py | 21 +++++++++++++++++++++ 4 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 src/copairs/timing.py diff --git a/src/copairs/compute.py b/src/copairs/compute.py index c03c1b5..7deabe7 100644 --- a/src/copairs/compute.py +++ b/src/copairs/compute.py @@ -9,6 +9,8 @@ import numpy as np from tqdm.autonotebook import tqdm +from copairs.timing import timing + def parallel_map(par_func: Callable[[int], None], items: np.ndarray) -> None: """Execute a function in parallel over a list of items. @@ -319,6 +321,7 @@ def average_precision(rel_k) -> np.ndarray: return ap_values.astype(np.float32) +@timing def ap_contiguous( rel_k_list: np.ndarray, counts: np.ndarray ) -> Tuple[np.ndarray, np.ndarray]: diff --git a/src/copairs/map/average_precision.py b/src/copairs/map/average_precision.py index f1782c5..5408f11 100644 --- a/src/copairs/map/average_precision.py +++ b/src/copairs/map/average_precision.py @@ -9,6 +9,7 @@ from copairs import compute from copairs.matching import UnpairedException, find_pairs +from copairs.timing import timing from .filter import evaluate_and_filter, flatten_str_list, validate_pipeline_input @@ -53,12 +54,10 @@ def build_rank_lists( Array of counts indicating how many times each profile index appears in the rank lists. """ # Combine relevance labels: 1 for positive pairs, 0 for negative pairs - labels = np.concatenate( - [ - np.ones(pos_pairs.size, dtype=np.uint32), - np.zeros(neg_pairs.size, dtype=np.uint32), - ] - ) + labels = np.concatenate([ + np.ones(pos_pairs.size, dtype=np.uint32), + np.zeros(neg_pairs.size, dtype=np.uint32), + ]) # Flatten positive and negative pair indices for ranking ix = np.concatenate([pos_pairs.ravel(), neg_pairs.ravel()]) @@ -222,6 +221,7 @@ def average_precision( return meta +@timing def p_values(dframe: pd.DataFrame, null_size: int, seed: int) -> np.ndarray: """Compute p-values for average precision scores based on a null distribution. diff --git a/src/copairs/matching.py b/src/copairs/matching.py index 032f2bc..19fc88b 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -12,6 +12,8 @@ import pandas as pd from tqdm.auto import tqdm +from copairs.timing import timing + logger = logging.getLogger("copairs") ColumnList = Union[Sequence[str], pd.Index] ColumnDict = Dict[str, ColumnList] @@ -466,6 +468,7 @@ def filter_fn(x): return {None: list(filter(filter_fn, all_pairs))} +@timing def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns. diff --git a/src/copairs/timing.py b/src/copairs/timing.py new file mode 100644 index 0000000..52454ce --- /dev/null +++ b/src/copairs/timing.py @@ -0,0 +1,21 @@ +from functools import wraps +from time import time + + +def timing(f): + @wraps(f) + def wrap(*args, **kw): + ts = time() + result = f(*args, **kw) + te = time() + args_to_print = list(args) + if hasattr(args[0], "__iter__"): + args_to_print = (*args[0].shape, *args[1:]) + + print( + "func:%r args:[%r, %r] took: %2.4f sec" + % (f.__name__, args_to_print, kw, te - ts) + ) + return result + + return wrap From 33269af2d4c13832adcc9997a88409abed48cdf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Wed, 19 Feb 2025 11:26:09 -0500 Subject: [PATCH 23/24] feat(matching): add multilabel support --- src/copairs/matching.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index 19fc88b..4cce057 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -472,7 +472,7 @@ def filter_fn(x): def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns. - `rev` reverses same and diff, which means that we get the complement + If `rev` is True sameby and diffby are swapped. """ sameby, diffby = _validate(sameby, diffby) @@ -481,6 +481,7 @@ def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: df = dframe.reset_index() with duckdb.connect(":memory:"): + # If rev is True, diffby and sameby are swapped group_1, group_2 = [ [f"{('', 'NOT')[i - rev]} A.{x} = B.{x}" for x in y] for i, y in enumerate((sameby, diffby)) @@ -507,3 +508,22 @@ def _validate(sameby, diffby): raise ValueError("at least one should be provided") return sameby, diffby + + +def find_pairs_multilabel(dframe, sameby, diffby, multilabel_col): + """ + You can include columns with multiple labels (i.e., a list of identifiers). + """ + + nested_col = multilabel_col + "_nested" + indexed = dframe.rename({multilabel_col: nested_col}, axis=1).reset_index() + + # Unnest (i.e., explode) the multilabel column and proceed as normal + with duckdb.connect(":memory:"): + unnested = duckdb.sql( + f"SELECT *,UNNEST({nested_col}) AS {multilabel_col} FROM indexed" + ) + + pairs = get_all_pairs(unnested, sameby, diffby) + + return pairs From 6415ecc4c88dca8f6c56e63b97f4a2def95bca72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Wed, 19 Feb 2025 11:32:34 -0500 Subject: [PATCH 24/24] fringe(find_pairs): condition the reset_index() step to pandas dfs --- src/copairs/matching.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index 4cce057..9d2a22a 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -469,7 +469,12 @@ def filter_fn(x): @timing -def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: +def find_pairs( + dframe: Union[pd.DataFrame, duckdb.duckdb.DuckDBPyRelation], + sameby: Union[str, ColumnList], + diffby: Union[str, ColumnList], + rev: bool = False, +) -> np.ndarray: """Find the indices pairs sharing values in `sameby` columns but not on `diffby` columns. If `rev` is True sameby and diffby are swapped. @@ -479,7 +484,9 @@ def find_pairs(dframe, sameby, diffby, rev=False) -> np.ndarray: if len(set(sameby).intersection(diffby)): raise ValueError("sameby and diffby must be disjoint lists") - df = dframe.reset_index() + df = dframe + if isinstance(df, pd.DataFrame): + df = dframe.reset_index() with duckdb.connect(":memory:"): # If rev is True, diffby and sameby are swapped group_1, group_2 = [ @@ -524,6 +531,6 @@ def find_pairs_multilabel(dframe, sameby, diffby, multilabel_col): f"SELECT *,UNNEST({nested_col}) AS {multilabel_col} FROM indexed" ) - pairs = get_all_pairs(unnested, sameby, diffby) + pairs = find_pairs(unnested, sameby, diffby) return pairs