Skip to content

Commit

Permalink
Use pydantic to validate config schema (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
piotrekkr authored Apr 17, 2024
1 parent 66ae7bd commit 018b48b
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 243 deletions.
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,6 @@ You can remove this file or selected lines from inside this file. This will trigger

## TODO

- Use pydantic instead of schema
- Add cli options to force re-download versions and document link lists
- Add download size to stats
- Add `--validate` option to `download` command that will trigger validation after download
- Add some example terraform and/or ansible to use for deploy to VM in cloud
- Use proper logging
150 changes: 122 additions & 28 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ python = "^3.11"
simple-salesforce = "^1.12.5"
PyYAML = "^6.0.1"
click = "^8.1.7"
schema = "^0.7.5"
python-dateutil = "^2.8.2"
types-PyYAML = "^6.0.12.12"
humanize = "^4.9.0"
pydantic = "^2.7.0"

[tool.poetry.group.dev.dependencies]
pytest-mock = "^3.12.0"
Expand Down
178 changes: 49 additions & 129 deletions src/salesforce_archivist/archivist.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import datetime
import os.path
from typing import Any
import os
from typing import Any, Dict

import click
import humanize
import yaml
from schema import And, Optional, Or, Schema, Use
from pydantic import BaseModel, Field, field_validator, ValidationInfo, computed_field
from typing import Optional
from typing_extensions import Annotated
from simple_salesforce import Salesforce as SalesforceClient

from salesforce_archivist.salesforce.api import SalesforceApiClient
Expand All @@ -14,40 +15,12 @@
from salesforce_archivist.salesforce.validation import ValidatedContentVersionList


class ArchivistObject:
def __init__(
self,
data_dir: str,
obj_type: str,
modified_date_lt: datetime.datetime | None = None,
modified_date_gt: datetime.datetime | None = None,
dir_name_field: str | None = None,
):
self._data_dir: str = os.path.join(data_dir, obj_type)
self._obj_type: str = obj_type
self._modified_date_lt: datetime.datetime | None = modified_date_lt
self._modified_date_gt: datetime.datetime | None = modified_date_gt
self._dir_name_field: str | None = dir_name_field

@property
def data_dir(self) -> str:
return self._data_dir

@property
def obj_type(self) -> str:
return self._obj_type

@property
def modified_date_lt(self) -> datetime.datetime | None:
return self._modified_date_lt

@property
def modified_date_gt(self) -> datetime.datetime | None:
return self._modified_date_gt

@property
def dir_name_field(self) -> str | None:
return self._dir_name_field
class ArchivistObject(BaseModel):
data_dir: Annotated[str, Field(min_length=1)]
obj_type: Annotated[str, Field(min_length=1)]
modified_date_lt: Optional[datetime.datetime] = None
modified_date_gt: Optional[datetime.datetime] = None
dir_name_field: Optional[str] = None

def __eq__(self, other: Any) -> bool:
if not isinstance(other, type(self)):
Expand All @@ -60,102 +33,49 @@ def __eq__(self, other: Any) -> bool:
other.modified_date_lt,
)


class ArchivistAuth:
def __init__(self, instance_url: str, username: str, consumer_key: str, private_key: str):
self._instance_url = instance_url
self._username = username
self._consumer_key = consumer_key
self._private_key = private_key

@property
def instance_url(self) -> str:
return self._instance_url

# https://github.com/python/mypy/issues/14461
@computed_field # type: ignore[misc]
@property
def username(self) -> str:
return self._username

@property
def consumer_key(self) -> str:
return self._consumer_key

@property
def private_key(self) -> str:
return self._private_key


class ArchivistConfig:
_schema = Schema(
{
"data_dir": And(str, len, os.path.isdir, error="data_dir must be set and be a directory"),
"max_api_usage_percent": Or(int, float, Use(float), lambda v: 0.0 < v <= 100.0),
Optional("max_workers"): Optional(int, lambda v: 0 < v),
Optional("modified_date_gt"): lambda d: isinstance(d, datetime.datetime),
Optional("modified_date_lt"): lambda d: isinstance(d, datetime.datetime),
"auth": {
"instance_url": And(str, len),
"username": And(str, len),
"consumer_key": And(str, len),
"private_key": And(bytes, len, Use(lambda b: b.decode("UTF-8"))),
},
"objects": {
str: {
Optional("modified_date_gt"): lambda d: isinstance(d, datetime.datetime),
Optional("modified_date_lt"): lambda d: isinstance(d, datetime.datetime),
Optional("dir_name_field"): And(str, len),
def obj_dir(self) -> str:
    """Return the storage directory for this object type: ``<data_dir>/<obj_type>``."""
    return os.path.join(self.data_dir, self.obj_type)


class ArchivistAuth(BaseModel):
    """Salesforce authentication settings parsed from the config file.

    Every field is required and must be a non-empty string
    (enforced via pydantic's ``Field(min_length=1)``).
    """

    # Base URL of the Salesforce instance the client connects to.
    instance_url: Annotated[str, Field(min_length=1)]
    # Username of the Salesforce account used for API access.
    username: Annotated[str, Field(min_length=1)]
    # Connected-app consumer key — presumably for the JWT bearer OAuth flow; confirm against SalesforceClient usage.
    consumer_key: Annotated[str, Field(min_length=1)]
    # Private key contents paired with the consumer key — assumed PEM text, TODO confirm.
    private_key: Annotated[str, Field(min_length=1)]


class ArchivistConfig(BaseModel):
auth: ArchivistAuth
data_dir: Annotated[str, Field(min_length=1)]
max_api_usage_percent: Optional[Annotated[float, Field(gt=0.0, le=100.0)]] = None
max_workers: Optional[Annotated[int, Field(gt=0)]] = None
modified_date_gt: Optional[datetime.datetime] = None
modified_date_lt: Optional[datetime.datetime] = None
objects: Dict[str, ArchivistObject]

@field_validator("objects", mode="before")
@classmethod
def serialize_categories(cls, objects: dict, info: ValidationInfo) -> dict:
for obj_type, obj_dict in objects.items():
obj_dict.update(
{
"obj_type": obj_type,
"data_dir": info.data["data_dir"],
"modified_date_gt": obj_dict.get("modified_date_gt", info.data["modified_date_gt"]),
"modified_date_lt": obj_dict.get("modified_date_lt", info.data["modified_date_lt"]),
}
},
}
)

def __init__(self, path: str):
with open(path) as file:
config = self._schema.validate(yaml.load(file, Loader=yaml.FullLoader))
self._auth: ArchivistAuth = ArchivistAuth(**config["auth"])
self._data_dir: str = config["data_dir"]
self._max_api_usage_percent: float = config["max_api_usage_percent"]
self.modified_date_gt: datetime.datetime | None = config.get("modified_date_gt")
self.modified_date_lt: datetime.datetime | None = config.get("modified_date_lt")
self._max_workers: int = config.get("max_workers")
self._objects = []
for obj_type, obj_config in config["objects"].items():
self._objects.append(
ArchivistObject(
data_dir=self._data_dir,
obj_type=obj_type,
modified_date_lt=obj_config.get("modified_date_lt", self.modified_date_lt),
modified_date_gt=obj_config.get("modified_date_gt", self.modified_date_gt),
dir_name_field=obj_config.get("dir_name_field"),
)
)

@property
def data_dir(self) -> str:
return self._data_dir

@property
def max_workers(self) -> int:
return self._max_workers

@property
def max_api_usage_percent(self) -> float:
return self._max_api_usage_percent

@property
def auth(self) -> ArchivistAuth:
return self._auth

@property
def objects(self) -> list[ArchivistObject]:
return self._objects
return objects


class Archivist:
def __init__(
self,
data_dir: str,
objects: list[ArchivistObject],
objects: dict[str, ArchivistObject],
sf_client: SalesforceClient,
max_api_usage_percent: float | None = None,
max_workers: int | None = None,
Expand All @@ -177,7 +97,7 @@ def download(self) -> bool:
"errors": 0,
"size": 0,
}
for archivist_obj in self._objects:
for archivist_obj in self._objects.values():
obj_type = archivist_obj.obj_type
salesforce = Salesforce(
archivist_obj=archivist_obj,
Expand All @@ -194,7 +114,7 @@ def download(self) -> bool:
download_list = DownloadContentVersionList(
document_link_list=document_link_list,
content_version_list=content_version_list,
data_dir=archivist_obj.data_dir,
data_dir=archivist_obj.obj_dir,
)
stats = salesforce.download_files(
download_content_version_list=download_list,
Expand Down Expand Up @@ -225,7 +145,7 @@ def validate(self) -> bool:
"processed": 0,
"invalid": 0,
}
for archivist_obj in self._objects:
for archivist_obj in self._objects.values():
salesforce = Salesforce(
archivist_obj=archivist_obj,
client=SalesforceApiClient(self._sf_client),
Expand All @@ -238,7 +158,7 @@ def validate(self) -> bool:
download_list = DownloadContentVersionList(
document_link_list=document_link_list,
content_version_list=content_version_list,
data_dir=archivist_obj.data_dir,
data_dir=archivist_obj.obj_dir,
)
stats = salesforce.validate_download(
download_content_version_list=download_list,
Expand Down
5 changes: 4 additions & 1 deletion src/salesforce_archivist/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from types import FrameType

import click
import yaml
from click import Context

from salesforce_archivist.archivist import Archivist, ArchivistConfig
Expand All @@ -20,7 +21,9 @@ def signal_handler(signum: int, frame: FrameType | None) -> None:
@click.pass_context
def cli(ctx: Context) -> None:
ctx.ensure_object(dict)
ctx.obj["config"] = ArchivistConfig("config.yaml")
with open("config.yaml") as file:
config = yaml.load(file, Loader=yaml.FullLoader)
ctx.obj["config"] = ArchivistConfig(**config)


@cli.command()
Expand Down
6 changes: 3 additions & 3 deletions src/salesforce_archivist/salesforce/salesforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
self._dir_name_field = dir_name_field

def _init_tmp_dir(self) -> str:
tmp_dir = os.path.join(self._archivist_obj.data_dir, "tmp")
tmp_dir = os.path.join(self._archivist_obj.obj_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
for entry in os.scandir(tmp_dir):
if entry.is_file():
Expand Down Expand Up @@ -90,7 +90,7 @@ def download_content_document_link_list(

def load_content_document_link_list(self) -> ContentDocumentLinkList:
document_link_list = ContentDocumentLinkList(
data_dir=self._archivist_obj.data_dir, dir_name_field=self._archivist_obj.dir_name_field
data_dir=self._archivist_obj.obj_dir, dir_name_field=self._archivist_obj.dir_name_field
)
if not document_link_list.data_file_exist():
self.download_content_document_link_list(document_link_list=document_link_list)
Expand All @@ -105,7 +105,7 @@ def load_content_version_list(
document_link_list: ContentDocumentLinkList,
batch_size: int = 3000,
) -> ContentVersionList:
content_version_list = ContentVersionList(data_dir=self._archivist_obj.data_dir)
content_version_list = ContentVersionList(data_dir=self._archivist_obj.obj_dir)
if not content_version_list.data_file_exist():
doc_id_list = [link.content_document_id for link in document_link_list]
list_size = len(doc_id_list)
Expand Down
Loading

0 comments on commit 018b48b

Please sign in to comment.