Skip to content

Commit

Permalink
Return more information in create_commit output (#1066)
Browse files Browse the repository at this point in the history
* Return more information in create_commit output

* flake8

* requested changes

* fix autocomplete test

* Add pr_revision and pr_url to CommitInfo

* Update tests/test_hf_api.py

Co-authored-by: Omar Sanseviero <[email protected]>

* nicely handle properties in dataclass

* make style

Co-authored-by: Omar Sanseviero <[email protected]>
  • Loading branch information
Wauplin and osanseviero authored Sep 23, 2022
1 parent e8801bd commit 5958f17
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 21 deletions.
22 changes: 22 additions & 0 deletions docs/source/package_reference/hf_api.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,30 @@ models = hf_api.list_models()

Using the `HfApi` class directly enables you to set a different endpoint from that of the Hugging Face Hub.

### HfApi

[[autodoc]] HfApi

### ModelInfo

[[autodoc]] huggingface_hub.hf_api.ModelInfo

### DatasetInfo

[[autodoc]] huggingface_hub.hf_api.DatasetInfo

### SpaceInfo

[[autodoc]] huggingface_hub.hf_api.SpaceInfo

### RepoFile

[[autodoc]] huggingface_hub.hf_api.RepoFile

### CommitInfo

[[autodoc]] huggingface_hub.hf_api.CommitInfo

## `create_commit` API

Below are the supported values for [`CommitOperation`]:
Expand All @@ -56,10 +70,18 @@ It does this using the [`HfFolder`] utility, which saves data at the root of the

Some helpers to filter repositories on the Hub are available in the `huggingface_hub` package.

### DatasetFilter

[[autodoc]] DatasetFilter

### ModelFilter

[[autodoc]] ModelFilter

### DatasetSearchArguments

[[autodoc]] DatasetSearchArguments

### ModelSearchArguments

[[autodoc]] ModelSearchArguments
2 changes: 2 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
"try_to_load_from_cache",
],
"hf_api": [
"CommitInfo",
"CommitOperation",
"CommitOperationAdd",
"CommitOperationDelete",
Expand Down Expand Up @@ -306,6 +307,7 @@ def __dir__():
from .file_download import hf_hub_download # noqa: F401
from .file_download import hf_hub_url # noqa: F401
from .file_download import try_to_load_from_cache # noqa: F401
from .hf_api import CommitInfo # noqa: F401
from .hf_api import CommitOperation # noqa: F401
from .hf_api import CommitOperationAdd # noqa: F401
from .hf_api import CommitOperationDelete # noqa: F401
Expand Down
88 changes: 76 additions & 12 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re
import subprocess
import warnings
from dataclasses import dataclass, field
from typing import BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import quote

Expand Down Expand Up @@ -171,6 +172,62 @@ class BlobLfsInfo(TypedDict, total=False):
sha256: str


@dataclass
class CommitInfo:
    """Data structure containing information about a newly created commit.

    Returned by [`create_commit`].

    Args:
        commit_url (`str`):
            Url where to find the commit.
        commit_message (`str`):
            The summary (first line) of the commit that has been created.
        commit_description (`str`):
            Description of the commit that has been created. Can be empty.
        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
        pr_url (`str`, *optional*):
            Url to the PR that has been created, if any. Populated when `create_pr=True`
            is passed.
        pr_revision (`str`, *optional*):
            Revision of the PR that has been created, if any. Populated when
            `create_pr=True` is passed. Example: `"refs/pr/1"`.
        pr_num (`int`, *optional*):
            Number of the PR discussion that has been created, if any. Populated when
            `create_pr=True` is passed. Can be passed as `discussion_num` in
            [`get_discussion_details`]. Example: `1`.
    """

    commit_url: str
    commit_message: str
    commit_description: str
    oid: str
    pr_url: Optional[str] = None

    # Computed from `pr_url` in `__post_init__`; not accepted by `__init__`.
    pr_revision: Optional[str] = field(init=False)
    # `__post_init__` stores an `int` here, so the annotation is `Optional[int]`.
    pr_num: Optional[int] = field(init=False)

    def __post_init__(self):
        """Populate pr-related fields after initialization.

        When `pr_url` is set, `pr_revision` is obtained from
        `_parse_revision_from_pr_url` and `pr_num` is the integer value of the
        last `/`-separated component of that revision (e.g. `"refs/pr/1"` -> `1`).
        Otherwise both fields are set to `None`.

        See https://docs.python.org/3.10/library/dataclasses.html#post-init-processing.
        """
        if self.pr_url is not None:
            self.pr_revision = _parse_revision_from_pr_url(self.pr_url)
            self.pr_num = int(self.pr_revision.split("/")[-1])
        else:
            self.pr_revision = None
            self.pr_num = None


class RepoFile:
"""
Data structure that represents a public file inside a repo, accessible from
Expand Down Expand Up @@ -1850,7 +1907,7 @@ def create_commit(
create_pr: Optional[bool] = None,
num_threads: int = 5,
parent_commit: Optional[str] = None,
) -> Optional[str]:
) -> CommitInfo:
"""
Creates a commit in the given repo, deleting & uploading files as needed.
Expand Down Expand Up @@ -1902,9 +1959,9 @@ def create_commit(
if the repo is updated / committed to concurrently.
Returns:
`str` or `None`:
If `create_pr` is `True`, returns the URL to the newly created Pull Request
on the Hub. Otherwise returns `None`.
[`CommitInfo`]:
Instance of [`CommitInfo`] containing information about the newly
created commit (commit hash, commit url, pr url, commit message,...).
Raises:
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
Expand Down Expand Up @@ -2015,7 +2072,14 @@ def create_commit(
params={"create_pr": "1"} if create_pr else None,
)
hf_raise_for_status(commit_resp, endpoint_name="commit")
return commit_resp.json().get("pullRequestUrl", None)
commit_data = commit_resp.json()
return CommitInfo(
commit_url=commit_data["commitUrl"],
commit_message=commit_message,
commit_description=commit_description,
oid=commit_data["commitOid"],
pr_url=commit_data["pullRequestUrl"] if create_pr else None,
)

@validate_hf_hub_args
def upload_file(
Expand Down Expand Up @@ -2157,7 +2221,7 @@ def upload_file(
path_in_repo=path_in_repo,
)

pr_url = self.create_commit(
commit_info = self.create_commit(
repo_id=repo_id,
repo_type=repo_type,
operations=[operation],
Expand All @@ -2169,8 +2233,8 @@ def upload_file(
parent_commit=parent_commit,
)

if pr_url is not None:
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
if commit_info.pr_url is not None:
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if repo_type in REPO_TYPES_URL_PREFIXES:
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
revision = revision if revision is not None else DEFAULT_REVISION
Expand Down Expand Up @@ -2317,7 +2381,7 @@ def upload_folder(
ignore_patterns=ignore_patterns,
)

pr_url = self.create_commit(
commit_info = self.create_commit(
repo_type=repo_type,
repo_id=repo_id,
operations=files_to_add,
Expand All @@ -2329,8 +2393,8 @@ def upload_folder(
parent_commit=parent_commit,
)

if pr_url is not None:
revision = quote(_parse_revision_from_pr_url(pr_url), safe="")
if commit_info.pr_url is not None:
revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if repo_type in REPO_TYPES_URL_PREFIXES:
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
revision = revision if revision is not None else DEFAULT_REVISION
Expand All @@ -2350,7 +2414,7 @@ def delete_file(
commit_description: Optional[str] = None,
create_pr: Optional[bool] = None,
parent_commit: Optional[str] = None,
):
) -> CommitInfo:
"""
Deletes a file in the given repo.
Expand Down
6 changes: 3 additions & 3 deletions src/huggingface_hub/keras_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def push_to_hub_keras(
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
pr_url = api.create_commit(
commit_info = api.create_commit(
repo_type="model",
repo_id=repo_id,
operations=operations,
Expand All @@ -458,8 +458,8 @@ def push_to_hub_keras(
revision = branch
if revision is None:
revision = (
quote(_parse_revision_from_pr_url(pr_url), safe="")
if pr_url is not None
quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="")
if commit_info.pr_url is not None
else DEFAULT_REVISION
)
return f"{api.endpoint}/{repo_id}/tree/{revision}/"
Expand Down
8 changes: 6 additions & 2 deletions src/huggingface_hub/utils/_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
from functools import wraps
from itertools import chain
from typing import Callable
from typing import TypeVar


REPO_ID_REGEX = re.compile(
Expand All @@ -40,7 +40,11 @@ class HFValidationError(ValueError):
"""


def validate_hf_hub_args(fn: Callable) -> Callable:
# type hint meaning "function signature not changed by decorator"
CallableT = TypeVar("CallableT") # callable type


def validate_hf_hub_args(fn: CallableT) -> CallableT:
"""Validate values received as argument for any public method of `huggingface_hub`.
The goal of this decorator is to harmonize validation of arguments reused
Expand Down
27 changes: 24 additions & 3 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from huggingface_hub.file_download import cached_download, hf_hub_download
from huggingface_hub.hf_api import (
USERNAME_PLACEHOLDER,
CommitInfo,
DatasetInfo,
DatasetSearchArguments,
HfApi,
Expand Down Expand Up @@ -766,10 +767,26 @@ def test_create_commit_create_pr(self):
token=self._token,
create_pr=True,
)

# Check commit info
self.assertIsInstance(resp, CommitInfo)
commit_id = resp.oid
self.assertIn("pr_revision='refs/pr/1'", str(resp))
self.assertIsInstance(commit_id, str)
self.assertGreater(len(commit_id), 0)
self.assertEqual(
resp.commit_url,
f"{self._api.endpoint}/{USER}/{REPO_NAME}/commit/{commit_id}",
)
self.assertEqual(resp.commit_message, "Test create_commit")
self.assertEqual(resp.commit_description, "")
self.assertEqual(
resp,
resp.pr_url,
f"{self._api.endpoint}/{USER}/{REPO_NAME}/discussions/1",
)
self.assertEqual(resp.pr_num, 1)
self.assertEqual(resp.pr_revision, "refs/pr/1")

with self.assertRaises(HTTPError) as ctx:
# Should raise a 404
hf_hub_download(
Expand Down Expand Up @@ -830,13 +847,17 @@ def test_create_commit(self):
path_or_fileobj=self.tmp_file,
),
]
return_val = self._api.create_commit(
resp = self._api.create_commit(
operations=operations,
commit_message="Test create_commit",
repo_id=f"{USER}/{REPO_NAME}",
token=self._token,
)
self.assertIsNone(return_val)
# Check commit info
self.assertIsInstance(resp, CommitInfo)
self.assertIsNone(resp.pr_url) # No pr created
self.assertIsNone(resp.pr_num)
self.assertIsNone(resp.pr_revision)
with self.assertRaises(HTTPError):
# Should raise a 404
hf_hub_download(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_init_lazy_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_autocomplete_on_root_imports(self) -> None:
self.assertTrue(
signature_list[0]
.docstring()
.startswith("create_commit(self, repo_id: str")
.startswith("create_commit(repo_id: str,")
)
break
else:
Expand Down

0 comments on commit 5958f17

Please sign in to comment.