Skip to content

Commit

Permalink
πŸ‘¨β€πŸ’» Configure HF Hub URL with environment variable (#815)
Browse files Browse the repository at this point in the history
* Make ENDPOINT configurable via an environment variable

* ⬆ Needs requests>=2.27 for JSONDecodeError

See https://docs.python-requests.org/en/latest/community/updates/#id2

* 🔧 Add private param to repo_type_and_id_from_hf_id

Only for testing purposes

* 💄 Code quality

* 🩹 Parentheses are important

* 👌 Suggested implementation

* 🔥 We don't need this anymore, do we ?

* Rename hf_api to client
  • Loading branch information
SBrandeis authored Apr 4, 2022
1 parent b287fba commit 9edcb30
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 30 deletions.
8 changes: 2 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def get_version() -> str:

install_requires = [
"filelock",
"requests",
"requests>=2.27",
"tqdm",
"pyyaml",
"typing-extensions>=3.7.4.3", # to be able to import TypeAlias
Expand All @@ -27,11 +27,7 @@ def get_version() -> str:
"torch",
]

extras["tensorflow"] = [
"tensorflow",
"pydot",
"graphviz"
]
extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"]

extras["testing"] = [
"pytest",
Expand Down
3 changes: 1 addition & 2 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
os.environ.get("HUGGINGFACE_CO_STAGING", "NO").upper() in ENV_VARS_TRUE_VALUES
)

ENDPOINT = (
ENDPOINT = os.getenv("HF_HUB_URL") or (
"https://moon-staging.huggingface.co" if _staging_mode else "https://huggingface.co"
)


HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}"

REPO_TYPE_DATASET = "dataset"
Expand Down
12 changes: 8 additions & 4 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import subprocess
import sys
import warnings
Expand Down Expand Up @@ -80,7 +81,7 @@ def _validate_repo_id_deprecation(repo_id, name, organization):
return name, organization


def repo_type_and_id_from_hf_id(hf_id: str):
def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None):
"""
Returns the repo type and ID from a huggingface.co URL linking to a
repository
Expand All @@ -94,16 +95,19 @@ def repo_type_and_id_from_hf_id(hf_id: str):
- <repo_type>/<namespace>/<repo_id>
- <namespace>/<repo_id>
- <repo_id>
hub_url (`str`, *optional*):
The URL of the HuggingFace Hub, defaults to https://huggingface.co
"""
is_hf_url = "huggingface.co" in hf_id and "@" not in hf_id
hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else ENDPOINT)
is_hf_url = hub_url in hf_id and "@" not in hf_id
url_segments = hf_id.split("/")
is_hf_id = len(url_segments) <= 3

if is_hf_url:
namespace, repo_id = url_segments[-2:]
if namespace == "huggingface.co":
if namespace == hub_url:
namespace = None
if len(url_segments) > 2 and "huggingface.co" not in url_segments[-3]:
if len(url_segments) > 2 and hub_url not in url_segments[-3]:
repo_type = url_segments[-3]
else:
repo_type = None
Expand Down
30 changes: 19 additions & 11 deletions src/huggingface_hub/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
from urllib.parse import urlparse

from tqdm.auto import tqdm

from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES, REPOCARD_NAME
from huggingface_hub.repocard import metadata_load, metadata_save

from .hf_api import ENDPOINT, HfApi, HfFolder, repo_type_and_id_from_hf_id
from .hf_api import HfApi, HfFolder, repo_type_and_id_from_hf_id
from .lfs import LFS_MULTIPART_UPLOAD_COMMAND
from .utils import logging

Expand Down Expand Up @@ -441,6 +442,7 @@ def __init__(
revision: Optional[str] = None,
private: bool = False,
skip_lfs_files: bool = False,
client: Optional[HfApi] = None,
):
"""
Instantiate a local clone of a git repo.
Expand Down Expand Up @@ -482,6 +484,9 @@ def __init__(
whether the repository is private or not.
skip_lfs_files (`bool`, *optional*, defaults to `False`):
whether to skip git-LFS files or not.
client (`HfApi`, *optional*):
Instance of HfApi to use when calling the HF Hub API.
A new instance will be created if this is left to `None`.
"""

os.makedirs(local_dir, exist_ok=True)
Expand All @@ -490,6 +495,7 @@ def __init__(
self.command_queue = []
self.private = private
self.skip_lfs_files = skip_lfs_files
self.client = client if client is not None else HfApi()

self.check_git_versions()

Expand All @@ -513,7 +519,7 @@ def __init__(
if self.huggingface_token is not None and (
git_email is None or git_user is None
):
user = HfApi().whoami(self.huggingface_token)
user = self.client.whoami(self.huggingface_token)

if git_email is None:
git_email = user["email"]
Expand Down Expand Up @@ -631,34 +637,36 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
"Couldn't load Hugging Face Authorization Token. Credentials are required to work with private repositories."
" Please login in using `huggingface-cli login` or provide your token manually with the `use_auth_token` key."
)
api = HfApi()

if "huggingface.co" in repo_url or (
hub_url = self.client.endpoint
if hub_url in repo_url or (
"http" not in repo_url and len(repo_url.split("/")) <= 2
):
repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(repo_url)
repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(
repo_url, hub_url=hub_url
)

if repo_type is not None:
self.repo_type = repo_type

repo_url = ENDPOINT + "/"
repo_url = hub_url + "/"

if self.repo_type in REPO_TYPES_URL_PREFIXES:
repo_url += REPO_TYPES_URL_PREFIXES[self.repo_type]

if token is not None:
whoami_info = api.whoami(token)
whoami_info = self.client.whoami(token)
user = whoami_info["name"]
valid_organisations = [org["name"] for org in whoami_info["orgs"]]

if namespace is not None:
repo_id = f"{namespace}/{repo_id}"
repo_url += repo_id

repo_url = repo_url.replace("https://", f"https://user:{token}@")
scheme = urlparse(repo_url).scheme
repo_url = repo_url.replace(f"{scheme}://", f"{scheme}://user:{token}@")

if namespace == user or namespace in valid_organisations:
api.create_repo(
self.client.create_repo(
repo_id=repo_id,
token=token,
repo_type=self.repo_type,
Expand All @@ -671,7 +679,7 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
repo_url += repo_id

# For error messages, it's cleaner to show the repo url without the token.
clean_repo_url = re.sub(r"https://.*@", "https://", repo_url)
clean_repo_url = re.sub(r"(https?)://.*@", r"\1://", repo_url)
try:
subprocess.run(
"git lfs install".split(),
Expand Down
5 changes: 4 additions & 1 deletion tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,4 +1277,7 @@ def test_repo_type_and_id_from_hf_id(self):
}

for key, value in possible_values.items():
self.assertEqual(repo_type_and_id_from_hf_id(key), tuple(value))
self.assertEqual(
repo_type_and_id_from_hf_id(key, hub_url="https://huggingface.co"),
tuple(value),
)
7 changes: 1 addition & 6 deletions tests/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,7 @@ def with_production_testing(func):
ENDPOINT_PRODUCTION,
)

repository = patch(
"huggingface_hub.repository.ENDPOINT",
ENDPOINT_PRODUCTION,
)

return repository(hf_api(file_download(func)))
return hf_api(file_download(func))


def retry_endpoint(function, number_of_tries: int = 3, wait_time: int = 5):
Expand Down

0 comments on commit 9edcb30

Please sign in to comment.