Skip to content

Commit

Permalink
huggingface-cli upload - Validate README.md before file hashing
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Aug 16, 2024
1 parent dfd73c0 commit 397f9d4
Showing 1 changed file with 85 additions and 57 deletions.
142 changes: 85 additions & 57 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3797,26 +3797,10 @@ def create_commit(
for addition in additions:
if addition.path_in_repo == "README.md":
with addition.as_file() as file:
response = get_session().post(
f"{ENDPOINT}/api/validate-yaml",
json={"content": file.read().decode(), "repoType": repo_type},
headers=headers,
)
# Handle warnings (example: empty metadata)
response_content = response.json()
message = "\n".join(
[f"- {warning.get('message')}" for warning in response_content.get("warnings", [])]
)
if message:
warnings.warn(f"Warnings while validating metadata in README.md:\n{message}")

# Raise on errors
try:
hf_raise_for_status(response)
except BadRequestError as e:
errors = response_content.get("errors", [])
message = "\n".join([f"- {error.get('message')}" for error in errors])
raise ValueError(f"Invalid metadata in README.md.\n{message}") from e
content = file.read().decode()
self._validate_yaml(content, repo_type=repo_type, token=token)
# Skip other additions after `README.md` has been processed
break

# If updating twice the same file or update then delete a file in a single commit
_warn_on_overwriting_operations(operations)
Expand Down Expand Up @@ -4875,11 +4859,13 @@ def upload_folder(
path_in_repo=path_in_repo,
delete_patterns=delete_patterns,
)
add_operations = _prepare_upload_folder_additions(
add_operations = self._prepare_upload_folder_additions(
folder_path,
path_in_repo,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
token=token,
repo_type=repo_type,
)

# Optimize operations: if some files will be overwritten, we don't need to delete them first
Expand Down Expand Up @@ -9182,6 +9168,84 @@ def _prepare_folder_deletions(
if relpath_to_abspath[relpath] != ".gitattributes"
]

def _prepare_upload_folder_additions(
self,
folder_path: Union[str, Path],
path_in_repo: str,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
repo_type: Optional[str] = None,
token: Union[bool, str, None] = None,
) -> List[CommitOperationAdd]:
"""Generate the list of Add operations for a commit to upload a folder.
Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist)
constraints are discarded.
"""

folder_path = Path(folder_path).expanduser().resolve()
if not folder_path.is_dir():
raise ValueError(f"Provided path: '{folder_path}' is not a directory")

# List files from folder
relpath_to_abspath = {
path.relative_to(folder_path).as_posix(): path
for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic
if path.is_file()
}

# Filter files
# Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering.
filtered_repo_objects = list(
filter_repo_objects(
relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
)

prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ""

# If updating a README.md file, make sure the metadata format is valid
# It's better to fail early than to fail after all the files have been hashed.
for relpath in filtered_repo_objects:
if relpath == "README.md":
self._validate_yaml(
content=relpath_to_abspath["README.md"].read_text(),
repo_type=repo_type,
token=token,
)
# Skip other additions after `README.md` has been processed
break

return [
CommitOperationAdd(
path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk
path_in_repo=prefix + relpath, # "absolute" path in repo
)
for relpath in filtered_repo_objects
]

def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None):
headers = self._build_hf_headers(token=token)

response = get_session().post(
f"{ENDPOINT}/api/validate-yaml",
json={"content": content, "repoType": repo_type},
headers=headers,
)
# Handle warnings (example: empty metadata)
response_content = response.json()
message = "\n".join([f"- {warning.get('message')}" for warning in response_content.get("warnings", [])])
if message:
warnings.warn(f"Warnings while validating metadata in README.md:\n{message}")

# Raise on errors
try:
hf_raise_for_status(response)
except BadRequestError as e:
errors = response_content.get("errors", [])
message = "\n".join([f"- {error.get('message')}" for error in errors])
raise ValueError(f"Invalid metadata in README.md.\n{message}") from e

def get_user_overview(self, username: str) -> User:
"""
Get an overview of a user on the Hub.
Expand Down Expand Up @@ -9275,42 +9339,6 @@ def list_user_following(self, username: str) -> Iterable[User]:
yield User(**followed_user)


def _prepare_upload_folder_additions(
folder_path: Union[str, Path],
path_in_repo: str,
allow_patterns: Optional[Union[List[str], str]] = None,
ignore_patterns: Optional[Union[List[str], str]] = None,
) -> List[CommitOperationAdd]:
"""Generate the list of Add operations for a commit to upload a folder.
Files not matching the `allow_patterns` (allowlist) and `ignore_patterns` (denylist)
constraints are discarded.
"""
folder_path = Path(folder_path).expanduser().resolve()
if not folder_path.is_dir():
raise ValueError(f"Provided path: '{folder_path}' is not a directory")

# List files from folder
relpath_to_abspath = {
path.relative_to(folder_path).as_posix(): path
for path in sorted(folder_path.glob("**/*")) # sorted to be deterministic
if path.is_file()
}

# Filter files and return
# Patterns are applied on the path relative to `folder_path`. `path_in_repo` is prefixed after the filtering.
prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ""
return [
CommitOperationAdd(
path_or_fileobj=relpath_to_abspath[relpath], # absolute path on disk
path_in_repo=prefix + relpath, # "absolute" path in repo
)
for relpath in filter_repo_objects(
relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
]


def _parse_revision_from_pr_url(pr_url: str) -> str:
"""Safely parse revision number from a PR url.
Expand Down

0 comments on commit 397f9d4

Please sign in to comment.