From 7214cbc0285c79590b3cd2e64bbe0286cf2496a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Al=C3=A1n=20F=2E=20Mu=C3=B1oz?= Date: Wed, 19 Feb 2025 14:51:34 -0500 Subject: [PATCH] style: document and format find_pairs_multilabel --- src/copairs/matching.py | 48 +++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/src/copairs/matching.py b/src/copairs/matching.py index ae24686..120de5a 100644 --- a/src/copairs/matching.py +++ b/src/copairs/matching.py @@ -523,15 +523,35 @@ def find_pairs_multilabel( sameby: Union[str, ColumnList], diffby: Union[str, ColumnList], multilabel_col: str, -): +) -> np.ndarray: """ - You can include columns with multiple labels (i.e., a list of identifiers). + Find pairs of rows in a DataFrame that have the same or different values in certain columns. + + The function takes into account columns with multiple labels (i.e., a list of identifiers). + + Parameters + ---------- + dframe : Union[pd.DataFrame, duckdb.duckdb.DuckDBPyRelation] + Input DataFrame. + sameby : Union[str, ColumnList] + List of column names to consider for finding identical values. + diffby : Union[str, ColumnList] + List of column names to consider for finding different values. + multilabel_col : str + Name of the column containing multiple labels. + + Returns + ------- + np.ndarray + Array of pairs of indices with matching or non-matching values in the specified columns. + + Notes + ----- + The function asserts that `multilabel_col` is present in either `sameby` or `diffby`. """ - assert (multilabel_col in sameby) or (multilabel_col in diffby), f"Missing {multilabel_col} in sameby and diffby" - nested_col = multilabel_col + "_nested" - indexed = dframe.rename({multilabel_col: nested_col}, axis=1).reset_index() + df = dframe.reset_index() if multilabel_col in sameby: sameby.remove(multilabel_col) @@ -541,14 +561,24 @@ def find_pairs_multilabel( shared_item = False with duckdb.connect(":memory:"): - result = duckdb.sql(f"SELECT * FROM (select *,CAST(len(list_intersect(A.{nested_col},B.{nested_col})) AS BOOL) AS shared_item FROM indexed A JOIN indexed B ON A.index < B.index) WHERE shared_item = {shared_item}") + result = duckdb.sql( + "SELECT * " + f" FROM (select *,CAST(len(list_intersect(A.{multilabel_col},B.{multilabel_col})) AS BOOL)" + " AS shared_item " + " FROM df A JOIN df B ON A.index < B.index)" + " WHERE shared_item = {shared_item}" + ) if len(sameby) or len(diffby): - monolabel_result = find_pairs(indexed, sameby, diffby).T - result = duckdb.sql(f"SELECT index, index_1 FROM result A JOIN monolabel_result B ON A.index = B.column0 AND A.index_1 = B.column1") + monolabel_result = find_pairs(df, sameby, diffby).T + result = duckdb.sql( + f"SELECT index, index_1" + " FROM result A JOIN monolabel_result B" + " ON A.index = B.column0 " + "AND A.index_1 = B.column1" + ) index_d = result.fetchnumpy() result = np.array((index_d["index"], index_d["index_1"]), dtype=np.uint32).T return result -