Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dtrai2 committed Dec 9, 2024
1 parent 8af6cb7 commit 94d43e2
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 20 deletions.
10 changes: 5 additions & 5 deletions logprep/processor/domain_label_extractor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def _apply_rules(self, event, rule: DomainLabelExtractorRule):
)
return

- urlsplit_result = urlsplit(domain)
- if urlsplit_result.hostname is not None:
- labels = Domain(urlsplit_result.hostname)
- else:
- labels = Domain(domain)
+ url = urlsplit(domain)
+ domain = url.hostname
+ if url.scheme == "":
+ domain = url.path
+ labels = Domain(domain)
if labels.suffix != "":
fields = {
f"{rule.target_field}.registered_domain": f"{labels.domain}.{labels.suffix}",
Expand Down
8 changes: 6 additions & 2 deletions logprep/processor/domain_resolver/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from multiprocessing import context
from multiprocessing.pool import ThreadPool
from typing import Optional
+ from urllib.parse import urlsplit

from attr import define, field, validators

Expand All @@ -47,7 +48,6 @@
from logprep.util.cache import Cache
from logprep.util.hasher import SHA256Hasher
from logprep.util.helper import add_fields_to, get_dotted_field_value
- from logprep.util.url.url import Domain

logger = logging.getLogger("DomainResolver")

Expand Down Expand Up @@ -151,7 +151,11 @@ def _apply_rules(self, event, rule):
domain_or_url_str = get_dotted_field_value(event, source_field)
if not domain_or_url_str:
return
- domain = Domain(domain_or_url_str).fqdn
+
+ url = urlsplit(domain_or_url_str)
+ domain = url.hostname
+ if url.scheme == "":
+ domain = url.path
if not domain:
return
self.metrics.total_urls += 1
Expand Down
8 changes: 5 additions & 3 deletions logprep/util/url/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,23 @@ def is_valid_scheme(value: str) -> bool:
class Domain:
"""Domain object for easy access to domain parts."""

- def __init__(self, fqdn: str):
- self.fqdn = fqdn
+ def __init__(self, domain_str: str):
+ self.domain_str = domain_str
+ self.fqdn = ""
self.subdomain = ""
self.domain = ""
self.suffix = ""
self._set_labels()

def _set_labels(self):
- suffix = self.fqdn
+ suffix = self.domain_str
while suffix != "":
_, _, suffix = suffix.partition(".")
if suffix in TLD_SET:
break
self.suffix = suffix
if self.suffix != "":
+ self.fqdn = self.domain_str
domain, _, _ = self.fqdn.rpartition(suffix)
self.subdomain, _, self.domain = domain.strip(".").rpartition(".")

Expand Down
21 changes: 11 additions & 10 deletions tests/unit/util/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,22 @@ def test_extract_urls_with_large_domain_label(self):
assert extract_urls(f"http://www.{domain_label}.com") == []

@pytest.mark.parametrize(
- "domain, expected_subdomain, expected_domain, expected_suffix",
+ "domain, expected_subdomain, expected_domain, expected_suffix, expected_fqdn",
[
- ("www.thedomain.com", "www", "thedomain", "com"),
- ("www.thedomain.co", "www", "thedomain", "co"),
- ("www.thedomain.com.ua", "www", "thedomain", "com.ua"),
- ("www.thedomain.co.uk", "www", "thedomain", "co.uk"),
- ("save.edu.ao", "", "save", "edu.ao"),
- ("thedomain.sport", "", "thedomain", "sport"),
- ("thedomain.联通", "", "thedomain", "联通"),
- ("www.thedomain.foobar", "", "", ""),
+ ("www.thedomain.com", "www", "thedomain", "com", "www.thedomain.com"),
+ ("www.thedomain.co", "www", "thedomain", "co", "www.thedomain.co"),
+ ("www.thedomain.com.ua", "www", "thedomain", "com.ua", "www.thedomain.com.ua"),
+ ("www.thedomain.co.uk", "www", "thedomain", "co.uk", "www.thedomain.co.uk"),
+ ("save.edu.ao", "", "save", "edu.ao", "save.edu.ao"),
+ ("thedomain.sport", "", "thedomain", "sport", "thedomain.sport"),
+ ("thedomain.联通", "", "thedomain", "联通", "thedomain.联通"),
+ ("www.thedomain.foobar", "", "", "", ""),
],
)
def test_get_labels_from_domain(
- self, domain, expected_subdomain, expected_domain, expected_suffix
+ self, domain, expected_subdomain, expected_domain, expected_suffix, expected_fqdn
):
assert Domain(domain).suffix == expected_suffix
assert Domain(domain).domain == expected_domain
assert Domain(domain).subdomain == expected_subdomain
+ assert Domain(domain).fqdn == expected_fqdn

0 comments on commit 94d43e2

Please sign in to comment.