Add multilabel support to duckdb implementation #82

Closed
wants to merge 25 commits

Commits (25)
d6b479f
deps: add duckdb and nix stuff
afermg Feb 14, 2025
76e2d11
feat: add temporary dev file
afermg Feb 14, 2025
c8380d8
dev: clean up and add test
afermg Feb 14, 2025
f68c34c
index on duckdb: c8380d8 dev: clean up and add test
afermg Feb 14, 2025
8b492ec
On duckdb: c8380d8 dev: clean up and add test
afermg Feb 14, 2025
f9af6fc
index on (no branch): 8b492ec On duckdb: c8380d8 dev: clean up and ad…
afermg Feb 14, 2025
517045b
deps: add ipdb and jupyter as dev deps.
afermg Feb 14, 2025
b7e296e
fix(find_pairs): The new function passes tests and goes brrr
afermg Feb 14, 2025
71420cb
clean: remove dev.py
afermg Feb 14, 2025
23874c7
clean: remove commented code on average_precision
afermg Feb 14, 2025
6114b17
style: ruff format
afermg Feb 14, 2025
1dcb0dc
fix(py38): remove typing that breaks CI checks for py38
afermg Feb 14, 2025
b7a4673
clean: delete sneaky leftover comments
afermg Feb 14, 2025
c237287
fix: cover arg errors inside find_pairs
afermg Feb 15, 2025
0812f17
tests: replace most Matcher tests with find_pairs
afermg Feb 15, 2025
c9365f7
feat(find_pairs): support set complement
afermg Feb 15, 2025
011baa8
tests: adjust test to replace Matcher with find_pairs
afermg Feb 15, 2025
b700463
fix(test): sample all the values
afermg Feb 15, 2025
b79f798
clean: (finally) remove matcher from test_matching
afermg Feb 15, 2025
b34158a
change(matching): replace local database for in-memory
afermg Feb 19, 2025
0800efa
deps: replace dev dep ipdb with trepan3k
afermg Feb 19, 2025
305b86a
clean(nix); remove unused lines in flake
afermg Feb 19, 2025
c85e5b5
feat: add timing decorator to matching, average_precision and p_val
afermg Feb 19, 2025
33269af
feat(matching): add multilabel support
afermg Feb 19, 2025
6415ecc
fringe(find_pairs): condition the reset_index() step to pandas dfs
afermg Feb 19, 2025
81 changes: 81 additions & 0 deletions flake.lock

(Generated file; diff not rendered.)

84 changes: 84 additions & 0 deletions flake.nix
@@ -0,0 +1,84 @@
{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixos-24.11";
    nixpkgs_master.url = "github:NixOS/nixpkgs/master";
    systems.url = "github:nix-systems/default";
    flake-utils.url = "github:numtide/flake-utils";
    flake-utils.inputs.systems.follows = "systems";
  };

  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
      ...
    }@inputs:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = import nixpkgs {
          inherit system;
          config.allowUnfree = true;
          config.cudaSupport = true;
        };

        mpkgs = import inputs.nixpkgs_master {
          inherit system;
          config.allowUnfree = true;
          config.cudaSupport = true;
        };

        libList =
          [
            # Add needed packages here
            pkgs.stdenv.cc.cc
            pkgs.libGL
            pkgs.glib
            pkgs.zlib
          ]
          ++ pkgs.lib.optionals pkgs.stdenv.isLinux (
            with pkgs;
            [
            ]
          );
      in
      with pkgs;
      {
        devShells = {
          default =
            let
              python_with_pkgs = pkgs.python311.withPackages (pp: [
                # Add python pkgs here that you need from nix repos
              ]);
            in
            mkShell {
              NIX_LD = runCommand "ld.so" { } ''
                ln -s "$(cat '${pkgs.stdenv.cc}/nix-support/dynamic-linker')" $out
              '';
              NIX_LD_LIBRARY_PATH = lib.makeLibraryPath libList;
              packages = [
                python_with_pkgs
                python311Packages.venvShellHook
                mpkgs.uv
              ] ++ libList;
              venvDir = "./.venv";
              postVenvCreation = ''
                unset SOURCE_DATE_EPOCH
              '';
              postShellHook = ''
                unset SOURCE_DATE_EPOCH
              '';
              shellHook = ''
                export LD_LIBRARY_PATH=$NIX_LD_LIBRARY_PATH:$LD_LIBRARY_PATH
                export PYTHON_KEYRING_BACKEND=keyring.backends.fail.Keyring
                runHook venvShellHook
                uv pip install -e .
                export PYTHONPATH=${python_with_pkgs}/${python_with_pkgs.sitePackages}:$PYTHONPATH
              '';
            };
        };
      }
    );
}
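Entering this shell with `nix develop` creates `./.venv` via venvShellHook and then has uv install the package in editable mode, while NIX_LD_LIBRARY_PATH exposes the Nix-provided native libraries to compiled Python wheels.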
18 changes: 16 additions & 2 deletions pyproject.toml
@@ -13,11 +13,16 @@ authors = [
 dependencies = [
     "pandas",
     "tqdm",
-    "statsmodels"
+    "statsmodels",
+    "duckdb>=1.2.0",
 ]

 [project.optional-dependencies]
-dev = ["ruff"]
+dev = [
+    "ipdb>=0.13.13",
+    "jupyter>=1.1.1",
+    "ruff",
+]
 plot = ["plotly"]
 test = ["scikit-learn", "pytest"]
 demo = ["notebook", "matplotlib"]
@@ -39,3 +44,12 @@ select = ["D"]

 [tool.ruff.lint.pydocstyle]
 convention = "numpy"
+
+[dependency-groups]
+dev = [
+    "jupyter>=1.1.1",
+    "trepan3k>=1.3.1",
+]
+test = [
+    "scikit-learn>=1.3.2",
+]
3 changes: 3 additions & 0 deletions src/copairs/compute.py
@@ -9,6 +9,8 @@
 import numpy as np
 from tqdm.autonotebook import tqdm

+from copairs.timing import timing
+

 def parallel_map(par_func: Callable[[int], None], items: np.ndarray) -> None:
     """Execute a function in parallel over a list of items.
@@ -319,6 +321,7 @@ def average_precision(rel_k) -> np.ndarray:
     return ap_values.astype(np.float32)


+@timing
 def ap_contiguous(
     rel_k_list: np.ndarray, counts: np.ndarray
 ) -> Tuple[np.ndarray, np.ndarray]:
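The `timing` decorator is imported from a new `copairs.timing` module whose file is not shown in this diff. A minimal sketch of what such a decorator could look like, assuming it only logs wall-clock time per call (the real module may record or report timings differently):

import functools
import logging
import time

logger = logging.getLogger("copairs")


def timing(func):
    """Log the wall-clock duration of each call to the wrapped function."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        logger.info("%s took %.4f s", func.__name__, time.perf_counter() - start)
        return result

    return wrapper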
40 changes: 11 additions & 29 deletions src/copairs/map/average_precision.py
@@ -8,7 +8,8 @@
 import pandas as pd

 from copairs import compute
-from copairs.matching import Matcher, UnpairedException
+from copairs.matching import UnpairedException, find_pairs
+from copairs.timing import timing

 from .filter import evaluate_and_filter, flatten_str_list, validate_pipeline_input

@@ -53,12 +54,10 @@ def build_rank_lists(
         Array of counts indicating how many times each profile index appears in the rank lists.
     """
     # Combine relevance labels: 1 for positive pairs, 0 for negative pairs
-    labels = np.concatenate(
-        [
-            np.ones(pos_pairs.size, dtype=np.uint32),
-            np.zeros(neg_pairs.size, dtype=np.uint32),
-        ]
-    )
+    labels = np.concatenate([
+        np.ones(pos_pairs.size, dtype=np.uint32),
+        np.zeros(neg_pairs.size, dtype=np.uint32),
+    ])

     # Flatten positive and negative pair indices for ranking
     ix = np.concatenate([pos_pairs.ravel(), neg_pairs.ravel()])
@@ -178,38 +177,20 @@ def average_precision(
     # Reset metadata index for consistent indexing
     meta = meta.reset_index(drop=True).copy()

-    # Initialize the Matcher object to find pairs based on metadata rules
-    logger.info("Indexing metadata...")
-    matcher = Matcher(meta, columns, seed=0)
-
     # Identify positive pairs based on `pos_sameby` and `pos_diffby`
     logger.info("Finding positive pairs...")
-    pos_pairs = matcher.get_all_pairs(sameby=pos_sameby, diffby=pos_diffby)
-    pos_total = sum(len(p) for p in pos_pairs.values())
-    if pos_total == 0:
+    pos_pairs = find_pairs(meta, sameby=pos_sameby, diffby=pos_diffby)
+    if len(pos_pairs) == 0:
         raise UnpairedException("Unable to find positive pairs.")

-    # Convert positive pairs to a NumPy array for efficient computation
-    pos_pairs = np.fromiter(
-        itertools.chain.from_iterable(pos_pairs.values()),
-        dtype=np.dtype((np.uint32, 2)),
-        count=pos_total,
-    )
-
     # Identify negative pairs based on `neg_sameby` and `neg_diffby`
     logger.info("Finding negative pairs...")
-    neg_pairs = matcher.get_all_pairs(sameby=neg_sameby, diffby=neg_diffby)
-    neg_total = sum(len(p) for p in neg_pairs.values())
-    if neg_total == 0:
+    neg_pairs = find_pairs(meta, sameby=neg_sameby, diffby=neg_diffby)
+    if len(neg_pairs) == 0:
         raise UnpairedException("Unable to find negative pairs.")

-    # Convert negative pairs to a NumPy array for efficient computation
-    neg_pairs = np.fromiter(
-        itertools.chain.from_iterable(neg_pairs.values()),
-        dtype=np.dtype((np.uint32, 2)),
-        count=neg_total,
-    )
-
     # Compute similarities for positive pairs
     logger.info("Computing positive similarities...")
     pos_sims = distance_fn(feats, pos_pairs, batch_size)
@@ -240,6 +221,7 @@
     return meta


+@timing
 def p_values(dframe: pd.DataFrame, null_size: int, seed: int) -> np.ndarray:
     """Compute p-values for average precision scores based on a null distribution.

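This change replaces the Matcher-based flow, where get_all_pairs returned a dict of index-tuple lists that then had to be flattened with np.fromiter, with find_pairs, which returns an (N, 2) uint32 array directly. A small illustration, using hypothetical metadata columns:

import pandas as pd

from copairs.matching import find_pairs

meta = pd.DataFrame({
    "compound": ["a", "a", "b", "b"],
    "plate": ["p1", "p2", "p1", "p2"],
})

# Positive pairs: same compound, different plate.
pos_pairs = find_pairs(meta, sameby=["compound"], diffby=["plate"])
print(pos_pairs)  # (N, 2) uint32 array, e.g. [[0 1] [2 3]]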
71 changes: 71 additions & 0 deletions src/copairs/matching.py
@@ -7,10 +7,13 @@
 from math import comb
 from typing import Dict, Sequence, Set, Union

+import duckdb
 import numpy as np
 import pandas as pd
 from tqdm.auto import tqdm

+from copairs.timing import timing
+
 logger = logging.getLogger("copairs")
 ColumnList = Union[Sequence[str], pd.Index]
 ColumnDict = Dict[str, ColumnList]
@@ -463,3 +466,71 @@
         return set(x) not in pairs

     return {None: list(filter(filter_fn, all_pairs))}


@timing
def find_pairs(
    dframe: Union[pd.DataFrame, duckdb.duckdb.DuckDBPyRelation],
    sameby: Union[str, ColumnList],
    diffby: Union[str, ColumnList],
    rev: bool = False,
) -> np.ndarray:
    """Find the index pairs that share values in the `sameby` columns but differ in the `diffby` columns.

    If `rev` is True, the conditions are swapped: pairs must differ in `sameby`
    and share values in `diffby`.
    """
    sameby, diffby = _validate(sameby, diffby)

    if len(set(sameby).intersection(diffby)):
        raise ValueError("sameby and diffby must be disjoint lists")

    df = dframe
    if isinstance(df, pd.DataFrame):
        df = dframe.reset_index()
    with duckdb.connect(":memory:"):
        # If rev is True, diffby and sameby are swapped
        group_1, group_2 = [
            [f"{('', 'NOT')[i - rev]} A.{x} = B.{x}" for x in y]
            for i, y in enumerate((sameby, diffby))
        ]
        string = (
            "SELECT A.index,B.index"
            " FROM df A"
            " JOIN df B"
            " ON A.index < B.index"  # Ensures only one of (a,b)/(b,a) and no (a,a)
            f" AND {' AND '.join((*group_1, *group_2))}"
        )
        index_d = duckdb.sql(string).fetchnumpy()

    return np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T
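For instance, a call like find_pairs(df, sameby=["compound"], diffby=["plate"]) (hypothetical column names) builds, up to whitespace, the query:

SELECT A.index,B.index FROM df A JOIN df B
ON A.index < B.index AND A.compound = B.compound AND NOT A.plate = B.plate

The self-join on the materialized `index` column is what keeps each unordered pair exactly once.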


def _validate(sameby, diffby):
    if isinstance(sameby, str):
        sameby = (sameby,)
    if isinstance(diffby, str):
        diffby = (diffby,)

    if not (len(sameby) or len(diffby)):
        raise ValueError("at least one of sameby or diffby must be provided")

    return sameby, diffby


def find_pairs_multilabel(dframe, sameby, diffby, multilabel_col):
    """Find pairs when `multilabel_col` holds lists of labels (i.e., multiple identifiers per row)."""
    nested_col = multilabel_col + "_nested"
    indexed = dframe.rename({multilabel_col: nested_col}, axis=1).reset_index()

    # Unnest (i.e., explode) the multilabel column and proceed as normal
    with duckdb.connect(":memory:"):
        unnested = duckdb.sql(
            f"SELECT *,UNNEST({nested_col}) AS {multilabel_col} FROM indexed"
        )

        pairs = find_pairs(unnested, sameby, diffby)

    return pairs
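A usage sketch with hypothetical data, assuming `target` holds lists of labels: rows pair up when they share at least one unnested label in `sameby` while differing in `diffby`:

import pandas as pd

from copairs.matching import find_pairs_multilabel

meta = pd.DataFrame({
    "target": [["geneA", "geneB"], ["geneB"], ["geneC"]],
    "plate": ["p1", "p2", "p1"],
})

# Rows 0 and 1 share the label "geneB" and sit on different plates.
pairs = find_pairs_multilabel(meta, sameby=["target"], diffby=["plate"], multilabel_col="target")
print(pairs)  # e.g. [[0 1]]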