Skip to content

Commit

Permalink
Repository fixes (#213)
Browse files: browse the repository at this point in the history
* Pass the `track_large_files` argument

* Test non-hf repo

* Authorize non-hf git clone

style

* Style
  • Loading branch information
LysandreJik authored Jul 22, 2021
1 parent 0097fc1 commit b6d03ce
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 27 deletions.
57 changes: 30 additions & 27 deletions src/huggingface_hub/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,39 +223,42 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non
token = use_auth_token if use_auth_token is not None else self.huggingface_token
api = HfApi()

repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(repo_url)
if "huggingface.co" in repo_url or (
"http" not in repo_url and len(repo_url.split("/")) <= 2
):
repo_type, namespace, repo_id = repo_type_and_id_from_hf_id(repo_url)

if repo_type is not None:
self.repo_type = repo_type
if repo_type is not None:
self.repo_type = repo_type

repo_url = ENDPOINT + "/"
repo_url = ENDPOINT + "/"

if self.repo_type in REPO_TYPES_URL_PREFIXES:
repo_url += REPO_TYPES_URL_PREFIXES[self.repo_type]
if self.repo_type in REPO_TYPES_URL_PREFIXES:
repo_url += REPO_TYPES_URL_PREFIXES[self.repo_type]

if token is not None:
whoami_info = api.whoami(token)
user = whoami_info["name"]
valid_organisations = [org["name"] for org in whoami_info["orgs"]]
if token is not None:
whoami_info = api.whoami(token)
user = whoami_info["name"]
valid_organisations = [org["name"] for org in whoami_info["orgs"]]

if namespace is not None:
repo_url += f"{namespace}/"
repo_url += repo_id
if namespace is not None:
repo_url += f"{namespace}/"
repo_url += repo_id

repo_url = repo_url.replace("https://", f"https://user:{token}@")
repo_url = repo_url.replace("https://", f"https://user:{token}@")

if namespace == user or namespace in valid_organisations:
api.create_repo(
token,
repo_id,
repo_type=self.repo_type,
organization=namespace,
exist_ok=True,
)
else:
if namespace is not None:
repo_url += f"{namespace}/"
repo_url += repo_id
if namespace == user or namespace in valid_organisations:
api.create_repo(
token,
repo_id,
repo_type=self.repo_type,
organization=namespace,
exist_ok=True,
)
else:
if namespace is not None:
repo_url += f"{namespace}/"
repo_url += repo_id

# For error messages, it's cleaner to show the repo url without the token.
clean_repo_url = re.sub(r"https://.*@", "https://", repo_url)
Expand Down Expand Up @@ -659,7 +662,7 @@ def commit(self, commit_message: str, track_large_files: bool = True):
try:
yield self
finally:
self.git_add(auto_lfs_track=True)
self.git_add(auto_lfs_track=track_large_files)

try:
self.git_commit(commit_message)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,13 @@ def test_clone_with_repo_name_org_and_no_auth_token(self):
git_email="[email protected]",
)

def test_clone_not_hf_url(self):
# Clone from a URL that contains "http" but not the literal "huggingface.co"
# (it uses the short "hf.co" host), so it does not match the Hub-URL check
# in `clone_from` and is presumably handed to git as a plain remote
# -- TODO confirm against the full `clone_from` implementation.
# Constructing the Repository should not error out.
Repository(
f"{WORKING_REPO_DIR}/{REPO_NAME}",
clone_from="https://hf.co/hf-internal-testing/huggingface-hub-dummy-repository",
)

@with_production_testing
def test_clone_repo_at_root(self):
os.environ["GIT_LFS_SKIP_SMUDGE"] = "1"
Expand Down

0 comments on commit b6d03ce

Please sign in to comment.