Commit 1c5b19b

refactor: pydantic config settings
- Create additional custom types to simplify model validators (see the sketch below)
- Split the custom pydantic definitions into `types`, `validators` and `converters`
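A minimal sketch of the pattern, for context (the `parse_list` body here is inferred from the `handle_falsy`/`handle_list` validators this commit deletes; the real implementation lives in `custom/validators.py`, which is not expanded on this page):

```python
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, NonNegativeInt


def parse_list(value: list | None) -> list:
    # Inferred behavior: collapse falsy input (None, "", []) to an empty list.
    return value if value else []


# Before: each model repeated a @field_validator(..., mode="before") doing this check.
# After: the coercion lives once, inside a reusable annotated type.
ListNonNegativeInt = Annotated[list[NonNegativeInt], BeforeValidator(parse_list)]


class DownloadOptions(BaseModel):
    maximum_number_of_children: ListNonNegativeInt = []


assert DownloadOptions(maximum_number_of_children=None).maximum_number_of_children == []
```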
NTFSvolume committed Jan 27, 2025
1 parent 722b42a commit 1c5b19b
Showing 8 changed files with 294 additions and 280 deletions.
16 changes: 8 additions & 8 deletions cyberdrop_dl/config_definitions/authentication_settings.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel, Field

-from .pydantic.custom_types import AliasModel
+from .custom.types import AliasModel


class ForumAuth(BaseModel):
@@ -74,10 +74,10 @@ class RealDebridAuth(AliasModel):
class AuthSettings(AliasModel):
    coomer: CoomerAuth = Field(validation_alias="Coomer", default=CoomerAuth())
    forums: ForumAuth = Field(validation_alias="Forums", default=ForumAuth())
-    gofile: GoFileAuth = Field(validation_alias="GoFile", default=GoFileAuth())
-    imgur: ImgurAuth = Field(validation_alias="Imgur", default=ImgurAuth())
-    jdownloader: JDownloaderAuth = Field(validation_alias="JDownloader", default=JDownloaderAuth())
-    pixeldrain: PixeldrainAuth = Field(validation_alias="PixelDrain", default=PixeldrainAuth())
-    realdebrid: RealDebridAuth = Field(validation_alias="RealDebrid", default=RealDebridAuth())
-    reddit: RedditAuth = Field(validation_alias="Reddit", default=RedditAuth())
-    xxxbunker: XXXBunkerAuth = Field(validation_alias="XXXBunker", default=XXXBunkerAuth())
+    gofile: GoFileAuth = Field(validation_alias="GoFile", default=GoFileAuth()) # type: ignore
+    imgur: ImgurAuth = Field(validation_alias="Imgur", default=ImgurAuth()) # type: ignore
+    jdownloader: JDownloaderAuth = Field(validation_alias="JDownloader", default=JDownloaderAuth()) # type: ignore
+    pixeldrain: PixeldrainAuth = Field(validation_alias="PixelDrain", default=PixeldrainAuth()) # type: ignore
+    realdebrid: RealDebridAuth = Field(validation_alias="RealDebrid", default=RealDebridAuth()) # type: ignore
+    reddit: RedditAuth = Field(validation_alias="Reddit", default=RedditAuth()) # type: ignore
+    xxxbunker: XXXBunkerAuth = Field(validation_alias="XXXBunker", default=XXXBunkerAuth()) # type: ignore
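The `validation_alias` on each field keeps the capitalized keys of the on-disk config, while `AliasModel` (defined in `custom/types.py` below) enables `populate_by_name`, so the snake_case attribute names validate too. A quick illustration — `CoomerAuth`'s `session` field is hypothetical, since that model's body is collapsed above:

```python
# Both calls should build the same model (the "session" field name is assumed):
auth_a = AuthSettings.model_validate({"Coomer": {"session": "abc123"}})  # via validation_alias
auth_b = AuthSettings(coomer=CoomerAuth(session="abc123"))               # via attribute name
assert auth_a == auth_b
```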
151 changes: 44 additions & 107 deletions cyberdrop_dl/config_definitions/config_settings.py
@@ -3,14 +3,25 @@
from pathlib import Path
from typing import Literal

-from pydantic import BaseModel, ByteSize, Field, NonNegativeInt, PositiveInt, field_serializer, field_validator
+from pydantic import BaseModel, ByteSize, Field, NonNegativeInt, PositiveInt, field_validator

-from cyberdrop_dl.config_definitions.pydantic.validators import parse_duration_to_timedelta
from cyberdrop_dl.utils.constants import APP_STORAGE, BROWSERS, DOWNLOAD_STORAGE
from cyberdrop_dl.utils.data_enums_classes.hash import Hashing
from cyberdrop_dl.utils.data_enums_classes.supported_domains import SUPPORTED_SITES_DOMAINS

-from .pydantic.custom_types import AliasModel, HttpAppriseURL, NonEmptyStr
+from .custom.types import (
+    AliasModel,
+    ByteSizeSerilized,
+    HttpAppriseURL,
+    ListNonEmptyStr,
+    ListNonNegativeInt,
+    LogPath,
+    MainLogPath,
+    NonEmptyStr,
+    NonEmptyStrOrNone,
+    PathOrNone,
+)
+from .custom.validators import parse_duration_as_timedelta, parse_falsy_as


class DownloadOptions(BaseModel):
@@ -26,14 +37,7 @@ class DownloadOptions(BaseModel):
    separate_posts_format: NonEmptyStr = "{default}"
    skip_download_mark_completed: bool = False
    skip_referer_seen_before: bool = False
-    maximum_number_of_children: list[NonNegativeInt] = []
-
-    @field_validator("maximum_number_of_children", mode="before")
-    @classmethod
-    def handle_falsy(cls, value: list) -> list:
-        if not value:
-            return []
-        return value
+    maximum_number_of_children: ListNonNegativeInt = []


class Files(AliasModel):
@@ -44,62 +48,33 @@ class Files(AliasModel):
class Logs(AliasModel):
    log_folder: Path = APP_STORAGE / "Configs" / "{config}" / "Logs"
    webhook: HttpAppriseURL | None = Field(validation_alias="webhook_url", default=None)
-    main_log: Path = Field(Path("downloader.log"), validation_alias="main_log_filename")
-    last_forum_post: Path = Field(Path("Last_Scraped_Forum_Posts.csv"), validation_alias="last_forum_post_filename")
-    unsupported_urls: Path = Field(Path("Unsupported_URLs.csv"), validation_alias="unsupported_urls_filename")
-    download_error_urls: Path = Field(Path("Download_Error_URLs.csv"), validation_alias="download_error_urls_filename")
-    scrape_error_urls: Path = Field(Path("Scrape_Error_URLs.csv"), validation_alias="scrape_error_urls_filename")
+    main_log: MainLogPath = Field(Path("downloader.log"), validation_alias="main_log_filename")
+    last_forum_post: LogPath = Field(Path("Last_Scraped_Forum_Posts.csv"), validation_alias="last_forum_post_filename")
+    unsupported_urls: LogPath = Field(Path("Unsupported_URLs.csv"), validation_alias="unsupported_urls_filename")
+    download_error_urls: LogPath = Field(Path("Download_Error_URLs.csv"), validation_alias="download_error_urls_filename") # fmt: skip
+    scrape_error_urls: LogPath = Field(Path("Scrape_Error_URLs.csv"), validation_alias="scrape_error_urls_filename")
    rotate_logs: bool = False
    log_line_width: PositiveInt = Field(default=240, ge=50)
    logs_expire_after: timedelta | None = None

    @field_validator("webhook", mode="before")
    @classmethod
    def handle_falsy(cls, value: str) -> str | None:
-        if not value:
-            return None
-        return value
-
-    @field_validator("main_log", mode="after")
-    @classmethod
-    def fix_main_log_extension(cls, value: Path) -> Path:
-        return value.with_suffix(".log")
-
-    @field_validator("last_forum_post", "unsupported_urls", "download_error_urls", "scrape_error_urls", mode="after")
-    @classmethod
-    def fix_other_logs_extensions(cls, value: Path) -> Path:
-        return value.with_suffix(".csv")
+        return parse_falsy_as(value, None)

    @field_validator("logs_expire_after", mode="before")
    @staticmethod
-    def parse_logs_duration(input_date: timedelta | str | int | None) -> timedelta:
-        """Parses `datetime.timedelta`, `str` or `int` into a timedelta format.
-        for `str`, the expected format is `value unit`, ex: `5 days`, `10 minutes`, `1 year`
-        valid units:
-            year(s), week(s), day(s), hour(s), minute(s), second(s), millisecond(s), microsecond(s)
-        for `int`, value is assumed as `days`
-        """
-        if input_date is None:
-            return None
-        return parse_duration_to_timedelta(input_date)
+    def parse_logs_duration(input_date: timedelta | str | int | None) -> timedelta | str | None:
+        return parse_falsy_as(input_date, None, parse_duration_as_timedelta)


class FileSizeLimits(BaseModel):
-    maximum_image_size: ByteSize = ByteSize(0)
-    maximum_other_size: ByteSize = ByteSize(0)
-    maximum_video_size: ByteSize = ByteSize(0)
-    minimum_image_size: ByteSize = ByteSize(0)
-    minimum_other_size: ByteSize = ByteSize(0)
-    minimum_video_size: ByteSize = ByteSize(0)
-
-    @field_serializer("*")
-    def human_readable(self, value: ByteSize | int) -> str:
-        if not isinstance(value, ByteSize):
-            value = ByteSize(value)
-        return value.human_readable(decimal=True)
+    maximum_image_size: ByteSizeSerilized = ByteSize(0)
+    maximum_other_size: ByteSizeSerilized = ByteSize(0)
+    maximum_video_size: ByteSizeSerilized = ByteSize(0)
+    minimum_image_size: ByteSizeSerilized = ByteSize(0)
+    minimum_other_size: ByteSizeSerilized = ByteSize(0)
+    minimum_video_size: ByteSizeSerilized = ByteSize(0)


class IgnoreOptions(BaseModel):
@@ -108,16 +83,9 @@ class IgnoreOptions(BaseModel):
    exclude_audio: bool = False
    exclude_other: bool = False
    ignore_coomer_ads: bool = False
-    skip_hosts: list[NonEmptyStr] = []
-    only_hosts: list[NonEmptyStr] = []
-    filename_regex_filter: NonEmptyStr | None = None
-
-    @field_validator("skip_hosts", "only_hosts", mode="before")
-    @classmethod
-    def handle_falsy(cls, value: list) -> list:
-        if not value:
-            return []
-        return value
+    skip_hosts: ListNonEmptyStr = []
+    only_hosts: ListNonEmptyStr = []
+    filename_regex_filter: NonEmptyStrOrNone = None


class RuntimeOptions(BaseModel):
@@ -129,50 +97,22 @@ class RuntimeOptions(BaseModel):
    delete_partial_files: bool = False
    update_last_forum_post: bool = True
    send_unsupported_to_jdownloader: bool = False
-    jdownloader_download_dir: Path | None = None
+    jdownloader_download_dir: PathOrNone = None
    jdownloader_autostart: bool = False
-    jdownloader_whitelist: list[NonEmptyStr] = []
+    jdownloader_whitelist: ListNonEmptyStr = []
    deep_scrape: bool = False
-    slow_download_speed: ByteSize = ByteSize(0)
-
-    @field_validator("jdownloader_download_dir", mode="before")
-    @classmethod
-    def handle_falsy(cls, value: str) -> str | None:
-        if not value or value == "None":
-            return None
-        return value
-
-    @field_validator("jdownloader_whitelist", mode="before")
-    @classmethod
-    def handle_list(cls, value: list) -> list:
-        if not value:
-            return []
-        return value
-
-    @field_serializer("slow_download_speed")
-    def human_readable(self, value: ByteSize | int) -> str:
-        if not isinstance(value, ByteSize):
-            value = ByteSize(value)
-        return value.human_readable(decimal=True)
+    slow_download_speed: ByteSizeSerilized = ByteSize(0)


# TODO: allow None values in sorting format to skip that type of file
class Sorting(BaseModel):
sort_downloads: bool = False
sort_folder: Path = DOWNLOAD_STORAGE / "Cyberdrop-DL Sorted Downloads"
scan_folder: Path | None = None
scan_folder: PathOrNone = None
sort_incrementer_format: NonEmptyStr = " ({i})"
sorted_audio: NonEmptyStr | None = "{sort_dir}/{base_dir}/Audio/{filename}{ext}"
sorted_image: NonEmptyStr | None = "{sort_dir}/{base_dir}/Images/{filename}{ext}"
sorted_other: NonEmptyStr | None = "{sort_dir}/{base_dir}/Other/{filename}{ext}"
sorted_video: NonEmptyStr | None = "{sort_dir}/{base_dir}/Videos/{filename}{ext}"

@field_validator("scan_folder", "sorted_audio", "sorted_image", "sorted_other", "sorted_video", mode="before")
@classmethod
def handle_falsy(cls, value: str) -> str | None:
if not value or value == "None":
return None
return value
sorted_audio: NonEmptyStrOrNone = "{sort_dir}/{base_dir}/Audio/{filename}{ext}"
sorted_image: NonEmptyStrOrNone = "{sort_dir}/{base_dir}/Images/{filename}{ext}"
sorted_other: NonEmptyStrOrNone = "{sort_dir}/{base_dir}/Other/{filename}{ext}"
sorted_video: NonEmptyStrOrNone = "{sort_dir}/{base_dir}/Videos/{filename}{ext}"


class BrowserCookies(BaseModel):
Expand All @@ -183,14 +123,13 @@ class BrowserCookies(BaseModel):
@field_validator("browsers", "sites", mode="before")
@classmethod
def handle_list(cls, values: list) -> list:
if not values:
return []
values = parse_falsy_as(values, [])
if isinstance(values, list):
return [str(value).lower() for value in values]
return values


-class DupeCleanupOptions(BaseModel):
+class DupeCleanup(BaseModel):
    hashing: Hashing = Hashing.IN_PLACE
    auto_dedupe: bool = True
    add_md5_hash: bool = False
@@ -201,12 +140,10 @@ class ConfigSettings(AliasModel):
class ConfigSettings(AliasModel):
    browser_cookies: BrowserCookies = Field(validation_alias="Browser_Cookies", default=BrowserCookies())
    download_options: DownloadOptions = Field(validation_alias="Download_Options", default=DownloadOptions())
-    dupe_cleanup_options: DupeCleanupOptions = Field(
-        validation_alias="Dupe_Cleanup_Options", default=DupeCleanupOptions()
-    )
+    dupe_cleanup_options: DupeCleanup = Field(validation_alias="Dupe_Cleanup_Options", default=DupeCleanup())
    file_size_limits: FileSizeLimits = Field(validation_alias="File_Size_Limits", default=FileSizeLimits())
    files: Files = Field(validation_alias="Files", default=Files())
    ignore_options: IgnoreOptions = Field(validation_alias="Ignore_Options", default=IgnoreOptions())
-    logs: Logs = Field(validation_alias="Logs", default=Logs())
+    logs: Logs = Field(validation_alias="Logs", default=Logs()) # type: ignore
    runtime_options: RuntimeOptions = Field(validation_alias="Runtime_Options", default=RuntimeOptions())
    sorting: Sorting = Field(validation_alias="Sorting", default=Sorting())
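`custom/validators.py` is among the four changed files not expanded on this page, but the behavior of `parse_falsy_as` can be inferred from its call sites above and from the `handle_falsy` validators it replaces: falsy input (and, judging by the old validators, the literal string `"None"`) maps to a default, anything else passes through an optional parser. A sketch under those assumptions, not the actual implementation:

```python
from typing import Any, Callable


def parse_falsy_as(value: Any, default: Any, parse: Callable | None = None) -> Any:
    # Inferred from call sites: parse_falsy_as(value, None),
    # parse_falsy_as(values, []), parse_falsy_as(date, None, parse_duration_as_timedelta)
    if not value or value == "None":
        return default
    return parse(value) if parse is not None else value
```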
64 changes: 64 additions & 0 deletions cyberdrop_dl/config_definitions/custom/converters.py
@@ -0,0 +1,64 @@
"""
Functions to use with `AfterValidator`, `field_validator(mode="after")` or `model_validator(mode="after")`
"""

import re
from datetime import timedelta
from pathlib import Path

from pydantic import AnyUrl, ByteSize, TypeAdapter
from yarl import URL

DATE_PATTERN_REGEX = r"(\d+)\s*(second|seconds|minute|minutes|hour|hours|day|days|week|weeks|month|months|year|years)"
DATE_PATTERN = re.compile(DATE_PATTERN_REGEX, re.IGNORECASE)

byte_size_adapter = TypeAdapter(ByteSize)


def convert_byte_size_to_str(value: ByteSize) -> str:
    if not isinstance(value, ByteSize):
        value = ByteSize(value)
    return value.human_readable(decimal=True)


def convert_to_yarl(value: AnyUrl) -> URL:
    return URL(str(value))


def change_path_suffix(value: Path, suffix: str) -> Path:
    return value.with_suffix(suffix)


def convert_to_byte_size(value: ByteSize | str | int) -> ByteSize:
    return byte_size_adapter.validate_python(value)


def convert_str_to_timedelta(input_date: str) -> timedelta:
    time_str = input_date.casefold()
    matches: list[str] = re.findall(DATE_PATTERN, time_str)
    seen_units = set()
    time_dict = {"days": 0}

    for value, unit in matches:
        value = int(value)
        unit = unit.lower()
        normalized_unit = unit.rstrip("s")
        plural_unit = normalized_unit + "s"
        if normalized_unit in seen_units:
            msg = f"Duplicate time unit detected: '{unit}' conflicts with another entry"
            raise ValueError(msg)
        seen_units.add(normalized_unit)

        if "day" in unit:
            time_dict["days"] += value
        elif "month" in unit:
            time_dict["days"] += value * 30
        elif "year" in unit:
            time_dict["days"] += value * 365
        else:
            time_dict[plural_unit] = value

    if not matches:
        msg = f"Unable to convert '{input_date}' to timedelta object"
        raise ValueError(msg)
    return timedelta(**time_dict)
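By way of example, the duration parser accepts compound strings and approximates months and years as 30 and 365 days respectively:

```python
from datetime import timedelta

assert convert_str_to_timedelta("5 days") == timedelta(days=5)
assert convert_str_to_timedelta("1 week 2 hours") == timedelta(weeks=1, hours=2)
assert convert_str_to_timedelta("1 year") == timedelta(days=365)   # approximation
assert convert_str_to_timedelta("2 months") == timedelta(days=60)  # approximation
# convert_str_to_timedelta("5 parsecs") raises ValueError: no recognized unit
```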
70 changes: 70 additions & 0 deletions cyberdrop_dl/config_definitions/custom/types.py
@@ -0,0 +1,70 @@
from __future__ import annotations

from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Annotated

from pydantic import (
    AfterValidator,
    AnyUrl,
    BaseModel,
    BeforeValidator,
    ByteSize,
    ConfigDict,
    HttpUrl,
    NonNegativeInt,
    PlainSerializer,
    Secret,
    SerializationInfo,
    StringConstraints,
    model_serializer,
    model_validator,
)

from .converters import change_path_suffix, convert_byte_size_to_str, convert_to_yarl
from .validators import parse_apprise_url, parse_falsy_as_none, parse_list

if TYPE_CHECKING:
    from yarl import URL

ByteSizeSerilized = Annotated[ByteSize, PlainSerializer(convert_byte_size_to_str, return_type=str)]
HttpURL = Annotated[HttpUrl, AfterValidator(convert_to_yarl)]
ListNonNegativeInt = Annotated[list[NonNegativeInt], BeforeValidator(parse_list)]

NonEmptyStr = Annotated[str, StringConstraints(min_length=1, strip_whitespace=True)]
NonEmptyStrOrNone = Annotated[NonEmptyStr | None, BeforeValidator(parse_falsy_as_none)]
ListNonEmptyStr = Annotated[list[NonEmptyStr], BeforeValidator(parse_list)]

PathOrNone = Annotated[Path | None, BeforeValidator(parse_falsy_as_none)]
LogPath = Annotated[Path, AfterValidator(partial(change_path_suffix, suffix=".csv"))]
MainLogPath = Annotated[LogPath, AfterValidator(partial(change_path_suffix, suffix=".log"))]


class AliasModel(BaseModel):
    model_config = ConfigDict(populate_by_name=True)


class FrozenModel(BaseModel):
    model_config = ConfigDict(frozen=True)


class AppriseURLModel(FrozenModel):
    url: Secret[AnyUrl]
    tags: set[str]

    @model_serializer()
    def serialize(self, info: SerializationInfo):
        dump_secret = info.mode != "json"
        url = self.url.get_secret_value() if dump_secret else self.url
        tags = self.tags - set("no_logs")
        tags = sorted(tags)
        return f"{','.join(tags)}{'=' if tags else ''}{url}"

    @model_validator(mode="before")
    @staticmethod
    def parse_input(value: URL | dict | str) -> dict:
        return parse_apprise_url(value)


class HttpAppriseURL(AppriseURLModel):
    url: Secret[HttpURL]
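Taken together, these annotated types move validation, conversion, and serialization out of the individual settings models and into the type annotations. A small demonstration of the expected behavior (the `Demo` model is hypothetical; note that stacked `AfterValidator`s apply left to right, so `MainLogPath` ends with the `.log` suffix):

```python
from pathlib import Path

from pydantic import BaseModel, ByteSize


class Demo(BaseModel):  # hypothetical model using the types above
    size: ByteSizeSerilized = ByteSize(0)
    nickname: NonEmptyStrOrNone = None
    main_log: MainLogPath = Path("downloader.log")


demo = Demo(size="5MB", nickname="", main_log=Path("my_log.csv"))
print(demo.nickname)              # None -- empty string treated as falsy
print(demo.main_log)              # my_log.log -- suffix normalized by MainLogPath
print(demo.model_dump()["size"])  # "5.0MB" -- human-readable serialization
```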
