From 94d43e236d3c00e690bc4beea5ba0467e77f09fb Mon Sep 17 00:00:00 2001 From: dtrai2 Date: Mon, 9 Dec 2024 13:20:22 +0100 Subject: [PATCH] fix tests --- .../domain_label_extractor/processor.py | 10 ++++----- .../processor/domain_resolver/processor.py | 8 +++++-- logprep/util/url/url.py | 8 ++++--- tests/unit/util/test_url.py | 21 ++++++++++--------- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/logprep/processor/domain_label_extractor/processor.py b/logprep/processor/domain_label_extractor/processor.py index 50646470e..3bb40a607 100644 --- a/logprep/processor/domain_label_extractor/processor.py +++ b/logprep/processor/domain_label_extractor/processor.py @@ -96,11 +96,11 @@ def _apply_rules(self, event, rule: DomainLabelExtractorRule): ) return - urlsplit_result = urlsplit(domain) - if urlsplit_result.hostname is not None: - labels = Domain(urlsplit_result.hostname) - else: - labels = Domain(domain) + url = urlsplit(domain) + domain = url.hostname + if url.scheme == "": + domain = url.path + labels = Domain(domain) if labels.suffix != "": fields = { f"{rule.target_field}.registered_domain": f"{labels.domain}.{labels.suffix}", diff --git a/logprep/processor/domain_resolver/processor.py b/logprep/processor/domain_resolver/processor.py index bbefaab09..5872fe08e 100644 --- a/logprep/processor/domain_resolver/processor.py +++ b/logprep/processor/domain_resolver/processor.py @@ -38,6 +38,7 @@ from multiprocessing import context from multiprocessing.pool import ThreadPool from typing import Optional +from urllib.parse import urlsplit from attr import define, field, validators @@ -47,7 +48,6 @@ from logprep.util.cache import Cache from logprep.util.hasher import SHA256Hasher from logprep.util.helper import add_fields_to, get_dotted_field_value -from logprep.util.url.url import Domain logger = logging.getLogger("DomainResolver") @@ -151,7 +151,11 @@ def _apply_rules(self, event, rule): domain_or_url_str = get_dotted_field_value(event, source_field) if not domain_or_url_str: return - domain = Domain(domain_or_url_str).fqdn + + url = urlsplit(domain_or_url_str) + domain = url.hostname + if url.scheme == "": + domain = url.path if not domain: return self.metrics.total_urls += 1 diff --git a/logprep/util/url/url.py b/logprep/util/url/url.py index dd4fa719d..564c6c927 100644 --- a/logprep/util/url/url.py +++ b/logprep/util/url/url.py @@ -144,21 +144,23 @@ def is_valid_scheme(value: str) -> bool: class Domain: """Domain object for easy access to domain parts.""" - def __init__(self, fqdn: str): - self.fqdn = fqdn + def __init__(self, domain_str: str): + self.domain_str = domain_str + self.fqdn = "" self.subdomain = "" self.domain = "" self.suffix = "" self._set_labels() def _set_labels(self): - suffix = self.fqdn + suffix = self.domain_str while suffix != "": _, _, suffix = suffix.partition(".") if suffix in TLD_SET: break self.suffix = suffix if self.suffix != "": + self.fqdn = self.domain_str domain, _, _ = self.fqdn.rpartition(suffix) self.subdomain, _, self.domain = domain.strip(".").rpartition(".") diff --git a/tests/unit/util/test_url.py b/tests/unit/util/test_url.py index 9ad7038a0..c5c501784 100644 --- a/tests/unit/util/test_url.py +++ b/tests/unit/util/test_url.py @@ -75,21 +75,22 @@ def test_extract_urls_with_large_domain_label(self): assert extract_urls(f"http://www.{domain_label}.com") == [] @pytest.mark.parametrize( - "domain, expected_subdomain, expected_domain, expected_suffix", + "domain, expected_subdomain, expected_domain, expected_suffix, expected_fqdn", [ - ("www.thedomain.com", "www", "thedomain", "com"), - ("www.thedomain.co", "www", "thedomain", "co"), - ("www.thedomain.com.ua", "www", "thedomain", "com.ua"), - ("www.thedomain.co.uk", "www", "thedomain", "co.uk"), - ("save.edu.ao", "", "save", "edu.ao"), - ("thedomain.sport", "", "thedomain", "sport"), - ("thedomain.联通", "", "thedomain", "联通"), - ("www.thedomain.foobar", "", "", ""), + ("www.thedomain.com", "www", "thedomain", "com", "www.thedomain.com"), + ("www.thedomain.co", "www", "thedomain", "co", "www.thedomain.co"), + ("www.thedomain.com.ua", "www", "thedomain", "com.ua", "www.thedomain.com.ua"), + ("www.thedomain.co.uk", "www", "thedomain", "co.uk", "www.thedomain.co.uk"), + ("save.edu.ao", "", "save", "edu.ao", "save.edu.ao"), + ("thedomain.sport", "", "thedomain", "sport", "thedomain.sport"), + ("thedomain.联通", "", "thedomain", "联通", "thedomain.联通"), + ("www.thedomain.foobar", "", "", "", ""), ], ) def test_get_labels_from_domain( - self, domain, expected_subdomain, expected_domain, expected_suffix + self, domain, expected_subdomain, expected_domain, expected_suffix, expected_fqdn ): assert Domain(domain).suffix == expected_suffix assert Domain(domain).domain == expected_domain assert Domain(domain).subdomain == expected_subdomain + assert Domain(domain).fqdn == expected_fqdn