Skip to content

Commit

Permalink
Merge branch 'timing' into duckdb
Browse files: browse the repository at this point in the history
  • Loading branch information
afermg committed Feb 19, 2025
2 parents f558e2e + 7214cbc commit 513c059
Showing 1 changed file with 39 additions and 9 deletions.
48 changes: 39 additions & 9 deletions src/copairs/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,15 +523,35 @@ def find_pairs_multilabel(
sameby: Union[str, ColumnList],
diffby: Union[str, ColumnList],
multilabel_col: str,
):
) -> np.ndarray:
"""
You can include columns with multiple labels (i.e., a list of identifiers).
Find pairs of rows in a DataFrame that have the same or different values in certain columns.
The function takes into account columns with multiple labels (i.e., a list of identifiers).
Parameters
----------
dframe : Union[pd.DataFrame, duckdb.duckdb.DuckDBPyRelation]
Input DataFrame.
sameby : Union[str, ColumnList]
List of column names to consider for finding identical values.
diffby : Union[str, ColumnList]
List of column names to consider for finding different values.
multilabel_col : str
Name of the column containing multiple labels.
Returns
-------
np.ndarray
Array of pairs of indices with matching or non-matching values in the specified columns.
Notes
-----
The function asserts that `multilabel_col` is present in either `sameby` or `diffby`.
"""

assert (multilabel_col in sameby) or (multilabel_col in diffby), f"Missing {multilabel_col} in sameby and diffby"

nested_col = multilabel_col + "_nested"
indexed = dframe.rename({multilabel_col: nested_col}, axis=1).reset_index()
df = dframe.reset_index()

if multilabel_col in sameby:
sameby.remove(multilabel_col)
Expand All @@ -541,14 +561,24 @@ def find_pairs_multilabel(
shared_item = False

with duckdb.connect(":memory:"):
result = duckdb.sql(f"SELECT * FROM (select *,CAST(len(list_intersect(A.{nested_col},B.{nested_col})) AS BOOL) AS shared_item FROM indexed A JOIN indexed B ON A.index < B.index) WHERE shared_item = {shared_item}")
result = duckdb.sql(
"SELECT * "
f" FROM (select *,CAST(len(list_intersect(A.{multilabel_col},B.{multilabel_col})) AS BOOL)"
" AS shared_item "
" FROM df A JOIN df B ON A.index < B.index)"
" WHERE shared_item = {shared_item}"
)

if len(sameby) or len(diffby):
monolabel_result = find_pairs(indexed, sameby, diffby).T
result = duckdb.sql(f"SELECT index, index_1 FROM result A JOIN monolabel_result B ON A.index = B.column0 AND A.index_1 = B.column1")
monolabel_result = find_pairs(df, sameby, diffby).T
result = duckdb.sql(
f"SELECT index, index_1"
" FROM result A JOIN monolabel_result B"
" ON A.index = B.column0 "
"AND A.index_1 = B.column1"
)

index_d = result.fetchnumpy()
result = np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T

return result

0 comments on commit 513c059

Please sign in to comment.