Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Tag Listing #537

Merged
merged 26 commits into from
Dec 20, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
)
from .utils.tags import DatasetTags, ModelTags


if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -417,6 +418,22 @@ def set_access_token(access_token: str):
def unset_access_token():
erase_from_credential_store(USERNAME_PLACEHOLDER)

def get_model_tags(self) -> ModelTags:
"Gets all valid model tags as a nested namespace object"
path = f"{api.endpoint}/api/models-tags-by-type"
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
r = requests.get(path)
r.raise_for_status()
d = r.json()
return ModelTags(d)

def get_dataset_tags(self) -> DatasetTags:
"Gets all valid dataset tags as a nested namespace object"
path = f"{api.endpoint}/api/datasets-tags-by-type"
r = requests.get(path)
r.raise_for_status()
d = r.json()
return DatasetTags(d)

muellerzr marked this conversation as resolved.
Show resolved Hide resolved
def list_models(
self,
filter: Union[str, Iterable[str], None] = None,
Expand Down
114 changes: 114 additions & 0 deletions src/huggingface_hub/utils/tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tagging utilities. """
muellerzr marked this conversation as resolved.
Show resolved Hide resolved


class AttributeDictionary(dict):
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
"""`dict` subclass that also provides access to keys as attributes

Example usage:

>>> d = AttributeDictionary()
>>> d["test"] = "a"
>>> print(d.test) # prints "a"

"""

def __getattr__(self, k):
if k in self:
return self[k]
else:
raise AttributeError(k)

def __setattr__(self, k, v):
(self.__setitem__, super().__setattr__)[k[0] == "_"](k, v)

def __delattr__(self, k):
if k in self:
del self[k]
else:
raise AttributeError(k)

def __dir__(self):
return super().__dir__() + list(self.keys())

def __repr__(self):
_ignore = [str(o) for o in dir(AttributeDictionary())]
repr_str = "Available Attributes:\n"
for o in dir(self):
if (o not in _ignore) and not (o.startswith("_")):
repr_str += f" * {o}\n"
return repr_str


class GeneralTags(AttributeDictionary):
"""
A namespace object holding all model tags, filtered by `keys`

Args:
tag_dictionary (``dict``):
A dictionary of tags returned from the /api/***-tags-by-type api endpoint
keys (``list``):
A list of keys to unpack the `tag_dictionary` with, such as `["library","language"]`
"""

def __init__(self, tag_dictionary: dict, keys: list = None):
self._tag_dictionary = tag_dictionary
if keys is None:
keys = list(self._tag_dictionary.keys())
for key in keys:
self._unpack_and_assign_dictionary(key)

def _unpack_and_assign_dictionary(self, key: str):
"Assignes nested attr to `self.key` containing information as an `AttrDict`"
setattr(self, key, AttributeDictionary())
for item in self._tag_dictionary[key]:
ref = getattr(self, key)
item["label"] = item["label"].replace(" ", "").replace("-", "_")
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
setattr(ref, item["label"], item["id"])


class ModelTags(GeneralTags):
"""
A namespace object holding all available model tags

Args:
model_tag_dictionary (``dict``):
A dictionary of valid model tags, returned from the /api/models-tags-by-type api endpoint
"""

def __init__(self, model_tag_dictionary: dict):
keys = ["library", "language", "license", "dataset", "pipeline_tag"]
super().__init__(model_tag_dictionary, keys)


class DatasetTags(GeneralTags):
"""
A namespace object holding all available dataset tags

Args:
dataset_tag_dictionary (``dict``):
A dictionary of valid dataset tags, returned from the /api/datasets-tags-by-type api endpoint
"""

def __init__(self, dataset_tag_dictionary: dict):
keys = [
"languages",
"multilinguality",
"language_creators",
"task_categories",
"size_categories",
"benchmark",
"task_ids",
"licenses",
]
super().__init__(dataset_tag_dictionary, keys)
149 changes: 149 additions & 0 deletions tests/test_tags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import requests
from huggingface_hub.hf_api import HfApi
from huggingface_hub.utils.tags import (
AttributeDictionary,
DatasetTags,
GeneralTags,
ModelTags,
)


class AttributeDictionaryCommonTest(unittest.TestCase):
_attrdict = AttributeDictionary()


class AttributeDictionaryTest(AttributeDictionaryCommonTest):
def test_adding_item(self):
self._attrdict["itemA"] = 2
self.assertEqual(self._attrdict.itemA, 2)
self.assertEqual(self._attrdict["itemA"], 2)
# We should be able to both set a property and a key
self._attrdict.itemB = 3
self.assertEqual(self._attrdict.itemB, 3)
self.assertEqual(self._attrdict["itemB"], 3)

def test_removing_item(self):
self._attrdict["itemA"] = 2
self._attrdict.itemB = 3
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
delattr(self._attrdict, "itemA")
with self.assertRaises(KeyError):
_ = self._attrdict["itemA"]

del self._attrdict["itemB"]
with self.assertRaises(AttributeError):
_ = self._attrdict.itemB

def test_dir(self):
# Since we subclass dict, dir should have everything
# from dict and the atttributes
_dict_keys = dir(dict) + [
"__dict__",
"__getattr__",
"__module__",
"__weakref__",
]
self._attrdict["itemA"] = 2
self._attrdict.itemB = 3
_dict_keys += ["itemA", "itemB"]
_dict_keys.sort()

full_dir = dir(self._attrdict)
full_dir.sort()
self.assertEqual(full_dir, _dict_keys)

def test_repr(self):
self._attrdict["itemA"] = 2
self._attrdict.itemB = 3
repr_string = "Available Attributes:\n * itemA\n * itemB\n"
self.assertEqual(repr_string, repr(self._attrdict))


class GeneralTagsCommonTest(unittest.TestCase):
# Similar to the output from /api/***-tags-by-type
# id = how we can search hfapi, such as `'id': 'languages:en'`
# label = A human readable version assigned to everything, such as `"label":"en"`
_tag_dictionary = {
"languages": [
{"id": "itemA", "label": "Item A"},
{"id": "itemB", "label": "Item-B"},
],
"license": [
{"id": "itemC", "label": "Item C"},
{"id": "itemD", "label": "Item-D"},
],
}
muellerzr marked this conversation as resolved.
Show resolved Hide resolved


class GeneralTagsTest(GeneralTagsCommonTest):
def test_init(self):
_tags = GeneralTags(self._tag_dictionary)
self.assertTrue(all(hasattr(_tags, kind) for kind in ["languages", "license"]))
languages = getattr(_tags, "languages")
licenses = getattr(_tags, "license")
# Ensure they have the right bits

self.assertEqual(
languages,
AttributeDictionary({"ItemA": "itemA", "Item_B": "itemB"}),
)
self.assertEqual(
licenses, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"})
)

def test_filter(self):
_tags = GeneralTags(self._tag_dictionary, keys=["license"])
self.assertTrue(hasattr(_tags, "license"))
with self.assertRaises(AttributeError):
_ = getattr(_tags, "languages")
self.assertEqual(
_tags.license, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"})
)


class ModelTagsTest(unittest.TestCase):
def test_tags(self):
_api = HfApi()
path = f"{_api.endpoint}/api/models-tags-by-type"
r = requests.get(path)
r.raise_for_status()
d = r.json()
o = ModelTags(d)
muellerzr marked this conversation as resolved.
Show resolved Hide resolved
for kind in ["library", "language", "license", "dataset", "pipeline_tag"]:
self.assertTrue(len(getattr(o, kind).keys()) > 0)


class DatasetTagsTest(unittest.TestCase):
def test_tags(self):
_api = HfApi()
path = f"{_api.endpoint}/api/datasets-tags-by-type"
r = requests.get(path)
r.raise_for_status()
d = r.json()
o = DatasetTags(d)
for kind in [
"languages",
"multilinguality",
"language_creators",
"task_categories",
"size_categories",
"benchmark",
"task_ids",
"licenses",
]:
self.assertTrue(len(getattr(o, kind).keys()) > 0)