Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable logging in with organization token #780

Merged
merged 11 commits into from
Mar 31, 2022
4 changes: 2 additions & 2 deletions src/huggingface_hub/commands/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ def _login(hf_api, username=None, password=None, token=None):
print(e)
print(ANSI.red(e.response.text))
exit(1)
elif not hf_api._is_valid_token(token):
raise ValueError("Invalid token passed.")
else:
token, name = hf_api._validate_or_retrieve_token(token)

hf_api.set_access_token(token)
HfFolder.save_token(token)
Expand Down
109 changes: 50 additions & 59 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,6 @@ def whoami(self, token: Optional[str] = None) -> Dict:
"You need to pass a valid `token` or login by using `huggingface-cli "
"login`"
)

path = f"{self.endpoint}/api/whoami-v2"
r = requests.get(path, headers={"authorization": f"Bearer {token}"})
try:
Expand Down Expand Up @@ -553,19 +552,21 @@ def _is_valid_token(self, token: str):
except HTTPError:
return False

def _validate_or_retrieve_token(self, token: Optional[Union[str, bool]] = None):
def _validate_or_retrieve_token(
self,
token: Optional[str] = None,
name: Optional[str] = None,
function_name: Optional[str] = None,
):
"""
Either retrieves stored token or validates passed token.

Retrieves and validates stored token or validates passed token.
Args:
token (`str`, *optional*):
The token to check for validity

Returns:
`str`: The valid token

Raises:
`ValueError`: if the token is invalid.
token (``str``, `optional`):
Hugging Face token. Will default to the locally saved token if not provided.
name (``str``, `optional`):
Name of the repository.
function_name (``str``, `optional`):
If called from a function, name of that function for deprecation warning.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Returns and Raises sections.

It would also be nice to explain why name and function_name are required and that they'd be removed in a future version (in v0.7 maybe?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

"""
if token is None or token is True:
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved
token = HfFolder.get_token()
Expand All @@ -574,9 +575,22 @@ def _validate_or_retrieve_token(self, token: Optional[Union[str, bool]] = None):
"You need to provide a `token` or be logged in to Hugging "
"Face with `huggingface-cli login`."
)
elif not self._is_valid_token(token):
raise ValueError("Invalid token passed!")
return token
if name is not None:
if self._is_valid_token(name):
# TODO(0.6) REMOVE
warnings.warn(
merveenoyan marked this conversation as resolved.
Show resolved Hide resolved
f"`{function_name}` now takes `token` as an optional positional argument. "
"Be sure to adapt your code!",
FutureWarning,
)
token, name = name, token
if isinstance(token, str):
if token.startswith("api_org"):
raise ValueError("You must use your personal account token.")
if not self._is_valid_token(token):
raise ValueError("Invalid token passed!")

return token, name
adrinjalali marked this conversation as resolved.
Show resolved Hide resolved

def logout(self, token: Optional[str] = None) -> None:
"""
Expand Down Expand Up @@ -758,7 +772,7 @@ def list_models(
"""
path = f"{self.endpoint}/api/models"
if use_auth_token:
token = self._validate_or_retrieve_token(use_auth_token)
token, name = self._validate_or_retrieve_token(use_auth_token)
headers = {"authorization": f"Bearer {token}"} if use_auth_token else None
params = {}
if filter is not None:
Expand Down Expand Up @@ -955,7 +969,7 @@ def list_datasets(
"""
path = f"{self.endpoint}/api/datasets"
if use_auth_token:
token = self._validate_or_retrieve_token(use_auth_token)
token, name = self._validate_or_retrieve_token(use_auth_token)
headers = {"authorization": f"Bearer {token}"} if use_auth_token else None
params = {}
if filter is not None:
Expand Down Expand Up @@ -1240,18 +1254,10 @@ def create_repo(
name, organization = _validate_repo_id_deprecation(repo_id, name, organization)

path = f"{self.endpoint}/api/repos/create"
if token is None:
token = self._validate_or_retrieve_token()
elif not self._is_valid_token(token):
if self._is_valid_token(name):
warnings.warn(
"`create_repo` now takes `token` as an optional positional argument. "
"Be sure to adapt your code!",
FutureWarning,
)
token, name = name, token
else:
raise ValueError("Invalid token passed!")

token, name = self._validate_or_retrieve_token(
token, name, function_name="create_repo"
)

checked_name = repo_type_and_id_from_hf_id(name)

Expand Down Expand Up @@ -1369,18 +1375,10 @@ def delete_repo(
name, organization = _validate_repo_id_deprecation(repo_id, name, organization)

path = f"{self.endpoint}/api/repos/delete"
if token is None:
token = self._validate_or_retrieve_token()
elif not self._is_valid_token(token):
if self._is_valid_token(name):
warnings.warn(
"`delete_repo` now takes `token` as an optional positional argument. "
"Be sure to adapt your code!",
FutureWarning,
)
token, name = name, token
else:
raise ValueError("Invalid token passed!")

token, name = self._validate_or_retrieve_token(
token, name, function_name="delete_repo"
)

checked_name = repo_type_and_id_from_hf_id(name)

Expand Down Expand Up @@ -1480,18 +1478,9 @@ def update_repo_visibility(

name, organization = _validate_repo_id_deprecation(repo_id, name, organization)

if token is None:
token = self._validate_or_retrieve_token()
elif not self._is_valid_token(token):
if self._is_valid_token(name):
warnings.warn(
"`update_repo_visibility` now takes `token` as an optional positional argument. "
"Be sure to adapt your code!",
FutureWarning,
)
token, name, private = name, private, token
else:
raise ValueError("Invalid token passed!")
token, name = self._validate_or_retrieve_token(
token, name, function_name="update_repo_visibility"
)

if organization is None:
namespace = self.whoami(token)["name"]
Expand Down Expand Up @@ -1548,7 +1537,8 @@ def move_repo(

- [1] https://huggingface.co/settings/tokens
"""
token = self._validate_or_retrieve_token(token)

token, name = self._validate_or_retrieve_token(token)

if len(from_id.split("/")) != 2:
raise ValueError(
Expand Down Expand Up @@ -1664,9 +1654,11 @@ def upload_file(
if repo_type not in REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")

if token is None:
token = self._validate_or_retrieve_token()
elif not self._is_valid_token(token):
try:
token, name = self._validate_or_retrieve_token(
token, function_name="upload_file"
)
except ValueError: # if token is invalid or organization token
if self._is_valid_token(path_or_fileobj):
warnings.warn(
"`upload_file` now takes `token` as an optional positional argument. "
Expand Down Expand Up @@ -1769,8 +1761,7 @@ def delete_file(
if repo_type not in REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")

if token is None:
token = self._validate_or_retrieve_token()
token, name = self._validate_or_retrieve_token(token)

if repo_type in REPO_TYPES_URL_PREFIXES:
repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id
Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def push_to_hub(
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)
Expand Down
23 changes: 23 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def test_login_cli(self):
read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)
)

def test_login_cli_org_fail(self):
with pytest.raises(
ValueError, match="You must use your personal account token."
):
_login(self._api, token="api_org_dummy_token")

def test_login_deprecation_error(self):
with pytest.warns(
FutureWarning,
Expand Down Expand Up @@ -561,6 +567,23 @@ def test_upload_file_bytesio(self):
finally:
self._api.delete_repo(repo_id=REPO_NAME, token=self._token)

@retry_endpoint
def test_create_repo_org_token_fail(self):
REPO_NAME = repo_name("org")
with pytest.raises(
ValueError, match="You must use your personal account token."
):
self._api.create_repo(repo_id=REPO_NAME, token="api_org_dummy_token")

@retry_endpoint
def test_create_repo_org_token_none_fail(self):
REPO_NAME = repo_name("org")
HfFolder.save_token("api_org_dummy_token")
with pytest.raises(
ValueError, match="You must use your personal account token."
):
self._api.create_repo(repo_id=REPO_NAME)

@retry_endpoint
def test_upload_file_conflict(self):
REPO_NAME = repo_name("conflict")
Expand Down