diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 6848742031..6556f6f146 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -318,13 +318,17 @@ def login(self, username: str, password: str) -> str: write_to_credential_store(username, password) return d["token"] - def whoami(self, token: str) -> Dict: + def whoami(self, token: Optional[str] = None) -> Dict: """ Call HF API to know "whoami". Args: - token (``str``): Hugging Face token. + token (``str``, `optional`): + Hugging Face token. Will default to the locally saved token if not provided. """ + if token is None: + token = HfFolder.get_token() + path = "{}/api/whoami-v2".format(self.endpoint) r = requests.get(path, headers={"authorization": "Bearer {}".format(token)}) try: @@ -797,6 +801,35 @@ def upload_file( d = r.json() return d["url"] + def get_full_repo_name( + self, + model_id: str, + organization: Optional[str] = None, + token: Optional[str] = None, + ): + """ + Returns the repository name for a given model ID and optional organization. + + Args: + model_id (``str``): + The name of the model. + organization (``str``, `optional`): + If passed, the repository name will be in the organization namespace instead of the + user namespace. + token (``str``, `optional`): + The Hugging Face authentication token + + Returns: + ``str``: The repository name in the user's namespace ({username}/{model_id}) if no + organization is passed, and under the organization namespace ({organization}/{model_id}) + otherwise. + """ + if organization is None: + username = self.whoami(token=token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + class HfFolder: path_token = expanduser("~/.huggingface/token") diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 713c4d39bf..389d6c7a22 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -341,6 +341,15 @@ def test_upload_file_conflict(self): finally: self._api.delete_repo(token=self._token, name=REPO_NAME) + def test_get_full_repo_name(self): + repo_name_with_no_org = self._api.get_full_repo_name("model", token=self._token) + self.assertEqual(repo_name_with_no_org, f"{USER}/model") + + repo_name_with_no_org = self._api.get_full_repo_name( + "model", organization="org", token=self._token + ) + self.assertEqual(repo_name_with_no_org, "org/model") + class HfApiPublicTest(unittest.TestCase): def test_staging_list_models(self):