From c3fc17e23c07f316f1a882383e8e97406f6e4a7d Mon Sep 17 00:00:00 2001 From: Alicia Bargar Date: Fri, 14 Jun 2024 18:39:08 +0200 Subject: [PATCH] Match functions (#65) * simplifying match functions, adding placeholder for partial text match * add partial string matching, drop duplicate records * update tests * unify pairwise matching * cleanup feature preprocessing, empty sets * simplify matching function to a single function to define * update tests, update threshold * rename function * update typing * update arg names --- modules/matcher.py | 247 +++++++++++++++++++++-------------------- requirements.txt | 3 +- tests/test__matcher.py | 158 ++++++++------------------ 3 files changed, 180 insertions(+), 228 deletions(-) diff --git a/modules/matcher.py b/modules/matcher.py index 61ff12c..952f916 100644 --- a/modules/matcher.py +++ b/modules/matcher.py @@ -2,19 +2,20 @@ import ast import json import logging -import traceback from functools import partial from itertools import chain from pathlib import Path -from typing import Any, Dict, Callable +from typing import Any, Callable import numpy as np import pandas as pd +from rapidfuzz import fuzz from pandas.api.types import is_list_like from modules.indicators import (EMBEDDED_IDS, FINANCIAL_IDS, SOCIAL_MEDIA_IDS, TRACKING_IDS, CRYPTO_IDS) + ## Preprocessing DOMAIN = "domain_name" @@ -24,11 +25,11 @@ MATCH_VALUE = "match_value" -def basic_preprocess(df: pd.DataFrame, feature: str) -> pd.DataFrame: - df = df[[DOMAIN, feature]] - df = df[~df[feature].isna() & ~df[feature].isnull()] +def basic_preprocess(df: pd.DataFrame) -> pd.DataFrame: + df = df[[DOMAIN, INDICATOR]] + df = df[~df[INDICATOR].isna() & ~df[INDICATOR].isnull()] + return df.drop_duplicates() - return df def column_contains_list_string(column: pd.Series) -> bool: # Note: this works off the assumption that all values will have the same type @@ -37,22 +38,30 @@ def column_contains_list_string(column: pd.Series) -> bool: except AttributeError: return False + def column_contains_set_string(column: pd.Series) -> bool: try: return column.iloc[0].startswith("{") except AttributeError: return False + def group_indicators(df: pd.DataFrame) -> pd.Series: if is_list_like(df[INDICATOR].iloc[0]): - return df.groupby(DOMAIN)[INDICATOR].agg(lambda x: set(chain.from_iterable(x))) - elif column_contains_list_string(df[INDICATOR]) or column_contains_set_string(df[INDICATOR]): - df_copy = df.copy() # avoid side effects with ast.literal + result = df.groupby(DOMAIN)[INDICATOR].agg( + lambda x: set(chain.from_iterable(x)) + ) + elif column_contains_list_string(df[INDICATOR]) or column_contains_set_string( + df[INDICATOR] + ): + df_copy = df.copy() # avoid side effects with ast.literal df_copy[INDICATOR] = df_copy[INDICATOR].map(ast.literal_eval) - return df_copy.groupby(DOMAIN)[INDICATOR].agg(lambda x: set(chain.from_iterable(x))) + result = df_copy.groupby(DOMAIN)[INDICATOR].agg( + lambda x: set(chain.from_iterable(x)) + ) else: - return df.groupby(DOMAIN)[INDICATOR].apply(set) - + result = df.groupby(DOMAIN)[INDICATOR].apply(set) + return result[result.str.len() > 0] # whois data @@ -113,85 +122,97 @@ def cert_preprocess(df: pd.DataFrame, cert_feature: str) -> pd.DataFrame: ## Matching -def direct_match( - feature_df: pd.DataFrame, - feature: str, - comparison_df: pd.DataFrame, - indicator=INDICATOR, -) -> pd.DataFrame: - # filter out invalid data - feature_df = basic_preprocess(feature_df, indicator) - comparison_df = basic_preprocess(comparison_df, indicator) - test_matches = pd.merge(feature_df, comparison_df, how="inner", on=indicator) +def direct_match(feature_df: pd.DataFrame, comparison_df: pd.DataFrame) -> pd.DataFrame: + test_matches = pd.merge(feature_df, comparison_df, how="inner", on=INDICATOR) # deduplicating matches = test_matches[test_matches.domain_name_x < test_matches.domain_name_y] - # note this throws a false positive SettingWithCopyWarning; the behavior is OK - matches[MATCH_TYPE] = feature - matches = matches.rename(columns={indicator: MATCH_VALUE}) + matches = matches.rename(columns={INDICATOR: MATCH_VALUE}) return matches.reset_index(drop=True) -# TODO: Add a partial string match function +def match_with_threshold( + feature_series: pd.Series, + comparison_series: pd.Series, + match_function: Callable[[Any, Any], Any], + threshold: int | float +): + match_data = [ + { + "domain_name_x": min(str(f_domain), str(c_domain)), + "domain_name_y": max(str(f_domain), str(c_domain)), + MATCH_VALUE: match_value + } + for f_domain, f_content in feature_series.items() + for c_domain, c_content in comparison_series.items() + if ( + (f_domain != c_domain) + and ( + (match_value := match_function(f_content, c_content)) + >= threshold + ) + ) + ] + # Create DataFrame from string matched data + result = pd.DataFrame( + match_data, columns=["domain_name_x", "domain_name_y", MATCH_VALUE] + ).drop_duplicates(["domain_name_x", "domain_name_y"]) + return result + +def partial_text_match( + feature_df: pd.DataFrame, + comparison_df: pd.DataFrame, + threshold: float = 0.9, +) -> pd.DataFrame: + + def text_similarity_score(x, y): + return fuzz.ratio(x, y) / 100.0 + + feature_series = feature_df.set_index(DOMAIN)[INDICATOR] + comparison_series = comparison_df.set_index(DOMAIN)[INDICATOR] + + return match_with_threshold( + feature_series, + comparison_series, + text_similarity_score, + threshold + ) + def iou_match( feature_df: pd.DataFrame, - feature: str, comparison_df: pd.DataFrame, threshold: float = 0.9, ) -> pd.DataFrame: + # Define IOU function def iou(set1, set2): - return len(set1.intersection(set2)) / (len(set1.union(set2)) + 0.000001) + return round(len(set1.intersection(set2)) / (len(set1.union(set2)) + 0.000001), 3) - # Convert data to sets - feature_sets = group_indicators(feature_df).to_dict() - comparison_sets = group_indicators(comparison_df).to_dict() + feature_series = group_indicators(feature_df) + comparison_series = group_indicators(comparison_df) - # Generate IOU data - iou_data = [ - { - "domain_name_x": f_domain, - "domain_name_y": c_domain, - MATCH_VALUE: round(iou(feature_sets[f_domain], comparison_sets[c_domain]), 3), - "matched_on": feature_sets[f_domain].intersection(comparison_sets[c_domain]) + return match_with_threshold( + feature_series, + comparison_series, + iou, + threshold + ) - } - for f_domain in feature_sets - for c_domain in comparison_sets - if f_domain < c_domain # deduplicate - ] - # Create DataFrame from IOU data - result = pd.DataFrame(iou_data, columns=["domain_name_x", "domain_name_y", "matched_on", MATCH_TYPE, MATCH_VALUE]) - if not result.empty: - result[MATCH_TYPE] = feature - result = result[result[MATCH_VALUE] >= threshold] +def any_in_list_match(feature_df: pd.DataFrame, comparison_df: pd.DataFrame): - return result + def any_in_list(x, y): + return len(x.intersection(y)) -def any_in_list_match( - feature_df: pd.DataFrame, - comparison_df: pd.DataFrame, - feature: str, -): - feature_sets = group_indicators(feature_df).to_dict() - comparison_sets = group_indicators(comparison_df).to_dict() - matches = [ - { - "domain_name_x": f_domain, - "domain_name_y": c_domain, - MATCH_TYPE: feature, - "matched_on": feature_sets[f_domain].intersection(comparison_sets[c_domain]) + feature_series = group_indicators(feature_df) + comparison_series = group_indicators(comparison_df) + + return match_with_threshold( + feature_series, + comparison_series, + any_in_list, + threshold=1 + ) - } - for f_domain in feature_sets - for c_domain in comparison_sets - if f_domain < c_domain # deduplicate - ] - matches_df = pd.DataFrame(matches, columns=["domain_name_x", "domain_name_y", "matched_on", MATCH_TYPE, MATCH_VALUE]) - if not matches_df.empty: - matches_df = matches_df[matches_df["matched_on"].map(lambda d: len(d)) > 0] - matches_df[MATCH_VALUE] = True - return matches_df.reset_index(drop=True) def parse_whois_matches( feature_df: pd.DataFrame, @@ -207,10 +228,9 @@ def parse_whois_matches( whois_feature_comparison_df = feature_df_preprocess(whois_comparison_df, sub_feature) matches = direct_match( whois_feature_df, - feature=sub_feature, - comparison_df=whois_feature_comparison_df, - indicator=sub_feature, + comparison_df=whois_feature_comparison_df ) + matches[MATCH_TYPE] = sub_feature feature_matches.append(matches) whois_matches = pd.concat(feature_matches) return whois_matches @@ -233,9 +253,7 @@ def parse_certificate_matches( ) matches = direct_match( cert_feature_df, - feature=sub_feature, - comparison_df=cert_feature_comparison_df, - indicator=sub_feature, + comparison_df=cert_feature_comparison_df ) feature_matches.append(matches) cert_matches = pd.concat(feature_matches) @@ -243,7 +261,7 @@ def parse_certificate_matches( ## Main program -FEATURE_MATCHING: dict[str, Callable[[pd.DataFrame, str, pd.DataFrame], pd.DataFrame]] = { +FEATURE_MATCHING: dict[str, Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame]] = { "1-cert-domain" : direct_match, "1-crypto-wallet" : direct_match, "1-domain" : direct_match, @@ -299,7 +317,7 @@ def parse_certificate_matches( "2-urlscan_urlssonpage" : iou_match, "2-urlscanhrefs" : iou_match, "2-techstack" : iou_match, -"3-footer-text": direct_match, +"3-footer-text": partial_text_match, "3-outbound-domain": iou_match, "2-ads_txt": iou_match } @@ -325,6 +343,27 @@ def parse_certificate_matches( DICT_FEATURES = {"whois": WHOIS_FEATURES, "certificate": URLSCAN_CERT_FEATURES} +def find_matches_on_feature(data, comparison, feature): + try: + match_func = FEATURE_MATCHING[feature] + except KeyError: + raise KeyError(f"No matching function defined for {feature}") + try: + logging.info(f"Matching {feature} with method: {match_func.__name__}") + except AttributeError as e: + logging.info(f"Matching {feature} with method: {match_func.func.__name__}, {match_func.keywords}, {e}") + # filter to relevant rows + feature_df = data[data[INDICATOR_TYPE] == feature] + comparison_df = comparison[comparison[INDICATOR_TYPE] == feature] + # preprocess + feature_df = basic_preprocess(feature_df) + comparison_df = basic_preprocess(comparison_df) + # match + feature_matches = match_func(feature_df, comparison_df) + feature_matches[MATCH_TYPE] = feature + return feature_matches + + def find_matches(data, comparison=None, result_dir=None) -> pd.DataFrame: matches_per_feature = [] unique_features = data[INDICATOR_TYPE].unique() @@ -333,51 +372,33 @@ def find_matches(data, comparison=None, result_dir=None) -> pd.DataFrame: comparison = data for feature in unique_features: - match_func = FEATURE_MATCHING.get(feature, None) - if not match_func: - logging.error(f"MISSING FEATURE MATCHING METHOD FOR: {feature}") - continue try: - logging.info(f"Matching {feature} with method: {match_func.__name__}") - except AttributeError as e: - logging.info(f"Matching {feature} with method: {match_func.func.__name__}, {match_func.keywords}, {e}") - feature_df = data[data[INDICATOR_TYPE] == feature] - comparison_df = comparison[comparison[INDICATOR_TYPE] == feature] - try: - feature_matches = match_func( - feature_df=feature_df, # type: ignore - feature=feature, - comparison_df=comparison_df - ) + feature_matches = find_matches_on_feature(data, comparison, feature) matches_per_feature.append(feature_matches) if result_dir: feature_matches.to_csv( f"{result_dir}/{feature}_matches.csv", index=False ) except Exception as e: - logging.error(f"Error ({e}) matching feature {feature}: {traceback.print_stack()}") + logging.error(f"Error matching feature {feature}", exc_info=e) continue all_matches = pd.concat(matches_per_feature) return all_matches -def define_output_filename(file1, file2 = None): +def define_output_filename(file1, file2=None): if file2: return f"{Path(file1).stem}_{Path(file2).stem}_results.csv" return f"{Path(file1).stem}_results.csv" - -def main(input_file, compare_file, result_dir, output_file, comparison_type): - if result_dir: - logging.info(f"we'll save intermediary results to the directory {result_dir}") - Path(result_dir).mkdir(exist_ok=True) +def main(input_file, compare_file, output_file, comparison_type): data1 = pd.read_csv(input_file) if comparison_type == "compare" and compare_file: data2 = pd.read_csv(compare_file) - matches: pd.DataFrame = find_matches(data1, data2, result_dir=result_dir) + matches: pd.DataFrame = find_matches(data1, data2) else: - matches = find_matches(data1, result_dir=result_dir) + matches = find_matches(data1) logging.info(f"Matches found: {matches.shape[0]}") logging.info( f"Summary of matches:\n{matches.groupby('match_type')['match_value'].count()}" @@ -395,14 +416,6 @@ def main(input_file, compare_file, result_dir, output_file, comparison_type): "-f", "--input-file", type=str, help="file of indicators to match", default="./indicators_output.csv" ) - parser.add_argument( - "-r", - "--result-dir", - type=str, - help="directory to save intermediary match results", - required=False, - default="./tmp/" - ) parser.add_argument( "-o", "--output-file", @@ -418,22 +431,20 @@ def main(input_file, compare_file, result_dir, output_file, comparison_type): type=str, help="type of comparison to run, pairwise or one-to-one compare", required=False, - default="pairwise" - ) + default="pairwise", + ) parser.add_argument( "-cf", "--compare-file", type=str, help="file of indicators to compare against", required=False, - default="./comparison_indicators.csv" + default="./comparison_indicators.csv", ) logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.StreamHandler() - ] + handlers=[logging.StreamHandler()], ) args = parser.parse_args() diff --git a/requirements.txt b/requirements.txt index 0606cde..d4a0a4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ lxml==5.1.0 bleach==6.1.0 psycopg2-binary==2.9.9 Flask-SQLAlchemy==3.1.1 -flask_migrate==4.0.7 \ No newline at end of file +flask_migrate==4.0.7 +rapidfuzz==3.8.1 \ No newline at end of file diff --git a/tests/test__matcher.py b/tests/test__matcher.py index 9cf01e4..a368601 100644 --- a/tests/test__matcher.py +++ b/tests/test__matcher.py @@ -11,46 +11,45 @@ DOMAIN, INDICATOR, INDICATOR_TYPE, - MATCH_TYPE, MATCH_VALUE, ) def feature_group_as_list_1(): return pd.DataFrame( [ - {DOMAIN: "a", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: [1, 2, 3]}, - {DOMAIN: "b", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: [3, 4, 5]}, - {DOMAIN: "c", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: [4, 5, 6]}, + {DOMAIN: "a", INDICATOR: [1, 2, 3]}, + {DOMAIN: "b", INDICATOR: [3, 4, 5]}, + {DOMAIN: "c", INDICATOR: [4, 5, 6]}, ] ) def feature_group_as_list_str_1(): return pd.DataFrame( [ - {DOMAIN: "a", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "[1, 2, 3]"}, - {DOMAIN: "b", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "[3, 4, 5]"}, - {DOMAIN: "c", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "[4, 5, 6]"}, + {DOMAIN: "a", INDICATOR: "[1, 2, 3]"}, + {DOMAIN: "b", INDICATOR: "[3, 4, 5]"}, + {DOMAIN: "c", INDICATOR: "[4, 5, 6]"}, ] ) def feature_group_as_list_str_2(): return pd.DataFrame( - columns=[DOMAIN, INDICATOR, INDICATOR_TYPE], + columns=[DOMAIN, INDICATOR], data=[ - {DOMAIN: "a", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "[7, 8, 9]"} + {DOMAIN: "a", INDICATOR: "[7, 8, 9]"} ], ) def feature_group_as_string_1(): return pd.DataFrame( [ - {DOMAIN: "a", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "foo"}, - {DOMAIN: "a", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "bar"}, - {DOMAIN: "b", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "bar"}, - {DOMAIN: "b", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "fake"}, - {DOMAIN: "b", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "phrase"}, - {DOMAIN: "c", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "fake"}, - {DOMAIN: "c", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "phrase"}, + {DOMAIN: "a", INDICATOR: "foo"}, + {DOMAIN: "a", INDICATOR: "bar"}, + {DOMAIN: "b", INDICATOR: "bar"}, + {DOMAIN: "b", INDICATOR: "fake"}, + {DOMAIN: "b", INDICATOR: "phrase"}, + {DOMAIN: "c", INDICATOR: "fake"}, + {DOMAIN: "c", INDICATOR: "phrase"}, ] ) @@ -60,14 +59,14 @@ def feature_group_as_string_1(): ( pd.DataFrame( [ - {DOMAIN: "foo", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "abc"}, - {DOMAIN: "bar", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "123"}, + {DOMAIN: "foo", INDICATOR: "abc"}, + {DOMAIN: "bar", INDICATOR: "123"}, ] ), pd.DataFrame( [ - {DOMAIN: "foo", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "abc"}, - {DOMAIN: "foo2", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "abc"}, + {DOMAIN: "foo", INDICATOR: "abc"}, + {DOMAIN: "foo2", INDICATOR: "abc"}, ] ), pd.DataFrame( @@ -76,7 +75,6 @@ def feature_group_as_string_1(): "domain_name_x": "foo", MATCH_VALUE: "abc", "domain_name_y": "foo2", - MATCH_TYPE: "feature", }, ] ), @@ -84,21 +82,21 @@ def feature_group_as_string_1(): ( pd.DataFrame( [ - {DOMAIN: "foo", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "abc"}, - {DOMAIN: "bar", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "123"}, + {DOMAIN: "foo", INDICATOR: "abc"}, + {DOMAIN: "bar", INDICATOR: "123"}, ] ), pd.DataFrame( - [{DOMAIN: "foo2", INDICATOR_TYPE: INDICATOR_TYPE, INDICATOR: "nope"}] + [{DOMAIN: "foo2", INDICATOR: "nope"}] ), pd.DataFrame( - [], columns=["domain_name_x", MATCH_VALUE, "domain_name_y", MATCH_TYPE] + [], columns=["domain_name_x", MATCH_VALUE, "domain_name_y"] ), ), ], ) def test__direct_match(feature_df, compare_df, expected_results): - matches = direct_match(feature_df, "feature", compare_df) + matches = direct_match(feature_df, compare_df) pd.testing.assert_frame_equal(matches, expected_results, check_index_type=False) @@ -110,22 +108,9 @@ def test__direct_match(feature_df, compare_df, expected_results): feature_group_as_list_str_1(), pd.DataFrame( [ - { - "domain_name_x": "a", - "domain_name_y": "b", - "match_type": "feature", - "match_value": 0.2, - }, - { - "domain_name_x": "a", - "domain_name_y": "c", - "match_type": "feature", - "match_value": 0.0, - }, { "domain_name_x": "b", "domain_name_y": "c", - "match_type": "feature", "match_value": 0.5, }, ] @@ -136,22 +121,9 @@ def test__direct_match(feature_df, compare_df, expected_results): feature_group_as_list_str_1(), pd.DataFrame( [ - { - "domain_name_x": "a", - "domain_name_y": "b", - "match_type": "feature", - "match_value": 0.2, - }, - { - "domain_name_x": "a", - "domain_name_y": "c", - "match_type": "feature", - "match_value": 0.0, - }, { "domain_name_x": "b", "domain_name_y": "c", - "match_type": "feature", "match_value": 0.5, }, ] @@ -162,41 +134,27 @@ def test__direct_match(feature_df, compare_df, expected_results): feature_group_as_string_1(), pd.DataFrame( [ - { - "domain_name_x": "a", - "domain_name_y": "b", - "match_type": "feature", - "match_value": 0.25, - }, - { - "domain_name_x": "a", - "domain_name_y": "c", - "match_type": "feature", - "match_value": 0.0, - }, { "domain_name_x": "b", "domain_name_y": "c", - "match_type": "feature", "match_value": 0.667, }, ] ), - id="two set-like strings, same values"), + id="two set-like strings, different values"), pytest.param( feature_group_as_list_str_1(), feature_group_as_list_str_2(), pd.DataFrame( - columns=["domain_name_x", "domain_name_y", "match_type", "match_value"] + columns=["domain_name_x", "domain_name_y", "match_value"] ), id="two listlike strings, different values"), ], ) def test__iou_match(feature_df, compare_df, expected_results): results = iou_match( - feature_df=feature_df, comparison_df=compare_df, feature="feature", threshold=0 + feature_df=feature_df, comparison_df=compare_df, threshold=0.5 ) - results = results.drop("matched_on", axis=1) # can't compare equality of sets pd.testing.assert_frame_equal(results, expected_results, check_index_type=False) @@ -218,16 +176,12 @@ def test__parse_certificate_matches(): { "domain_name_x": "a", "domain_name_y": "b", - "match_type": "feature", - "match_value": True, - # "matched_on": {3}, + "match_value": 1, }, { "domain_name_x": "b", "domain_name_y": "c", - "match_type": "feature", - "match_value": True, - # "matched_on" : {4, 5}, + "match_value": 2, }, ] ) @@ -240,16 +194,12 @@ def test__parse_certificate_matches(): { "domain_name_x": "a", "domain_name_y": "b", - "match_type": "feature", - "match_value": True, - # "matched_on": {'bar'}, + "match_value": 1, }, { "domain_name_x": "b", "domain_name_y": "c", - "match_type": "feature", - "match_value": True, - # "matched_on" : {'phrase', 'fake'}, + "match_value": 2, }, ] ) @@ -258,16 +208,15 @@ def test__parse_certificate_matches(): feature_group_as_list_str_1(), feature_group_as_list_str_2(), pd.DataFrame( - columns=["domain_name_x", "domain_name_y", "match_type", "match_value"], + columns=["domain_name_x", "domain_name_y", "match_value"], data=[] ) ), ] ) def test__any_in_list_match(feature_df, compare_df, expected_results): - results = any_in_list_match(feature_df, compare_df, feature='feature') - results = results.drop("matched_on", axis=1) - pd.testing.assert_frame_equal(results, expected_results, check_index_type=False) + results = any_in_list_match(feature_df, compare_df) + pd.testing.assert_frame_equal(results.reset_index(drop=True), expected_results.reset_index(drop=True), check_index_type=False) def test__dict_direct_match(): @@ -283,7 +232,7 @@ def test__abs_difference_vs_threshold(): @pytest.mark.parametrize( - "data,comparison,result_dir", + "data,comparison", [ ( pd.DataFrame( @@ -296,7 +245,6 @@ def test__abs_difference_vs_threshold(): columns=[DOMAIN, INDICATOR_TYPE, INDICATOR], ), None, - None, ), ( pd.DataFrame( @@ -317,25 +265,24 @@ def test__abs_difference_vs_threshold(): ], columns=[DOMAIN, INDICATOR_TYPE, INDICATOR], ), - None, ), ], ) -def test__find_matches(data, comparison, result_dir): - find_matches(data, comparison, result_dir) +def test__find_matches(data, comparison): + find_matches(data, comparison) @pytest.mark.parametrize( - "input_file,compare_file,comparison_type,result_dir,output_file", + "input_file,compare_file,comparison_type,output_file", [ - ("i_file", "c_file", "compare", "r_dir", "r_file"), - ("i_file", "c_file", "pairwise", "r_dir", "r_file"), - ("i_file", "c_file", "compare", "r_dir", None), - ("i_file", "c_file", "pairwise", "r_dir", None), - ("i_file", "c_file", "compare", None, "r_file"), - ("i_file", "c_file", "pairwise", None, "r_file"), - ("i_file", None, "compare", "r_dir", "r_file"), - ("i_file", None, "pairwise", "r_dir", "r_file"), + ("i_file", "c_file", "compare", "r_file"), + ("i_file", "c_file", "pairwise", "r_file"), + ("i_file", "c_file", "compare", None), + ("i_file", "c_file", "pairwise", None), + ("i_file", "c_file", "compare", "r_file"), + ("i_file", "c_file", "pairwise", "r_file"), + ("i_file", None, "compare", "r_file"), + ("i_file", None, "pairwise", "r_file"), ], ) @mock.patch("modules.matcher.find_matches") @@ -350,24 +297,17 @@ def test__main( input_file, compare_file, comparison_type, - result_dir, output_file, ): mock_define_output_filename.return_value = "r_file" mock_read_csv.return_value = "fake_data" - main(input_file, compare_file, result_dir, output_file, comparison_type) - if result_dir: - mock_mkdir.assert_called_once_with(exist_ok=True) - else: - mock_mkdir.assert_not_called() + main(input_file, compare_file, output_file, comparison_type) if not output_file: mock_define_output_filename.assert_called_once_with(input_file, compare_file) output_file = "r_file" else: mock_define_output_filename.assert_not_called() if comparison_type == "compare" and compare_file: - mock_find_matches.assert_called_with( - "fake_data", "fake_data", result_dir=result_dir - ) + mock_find_matches.assert_called_with("fake_data", "fake_data" ) else: - mock_find_matches.assert_called_once_with("fake_data", result_dir=result_dir) + mock_find_matches.assert_called_once_with("fake_data")