diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000000..1430c20a18
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include src/huggingface_hub/templates/modelcard_template.md
+include src/huggingface_hub/templates/datasetcard_template.md
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 150e18ab17..3ab66c2a8f 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -19,6 +19,8 @@
title: Interact with Discussions and Pull Requests
- local: how-to-cache
title: Manage the Cache
+ - local: how-to-model-cards
+ title: Create and Share Model Cards
title: "Guides"
- sections:
- local: package_reference/repository
@@ -37,4 +39,6 @@
title: Discussions and Pull Requests
- local: package_reference/cache
title: Cache-system reference
+ - local: package_reference/cards
+ title: Repo Cards and Repo Card Data
title: "Reference"
\ No newline at end of file
diff --git a/docs/source/how-to-model-cards.mdx b/docs/source/how-to-model-cards.mdx
new file mode 100644
index 0000000000..b380642aee
--- /dev/null
+++ b/docs/source/how-to-model-cards.mdx
@@ -0,0 +1,315 @@
+# Creating and Sharing Model Cards
+
+The `huggingface_hub` library provides a Python interface to create, share, and update Model Cards.
+Visit [the dedicated documentation page](https://huggingface.co/docs/hub/models-cards)
+for a deeper view of what Model Cards on the Hub are, and how they work under the hood.
+
+## Loading a Model Card from the Hub
+
+To load an existing card from the Hub, you can use the [`ModelCard.load`] function. Here, we'll load the card from [`nateraw/vit-base-beans`](https://huggingface.co/nateraw/vit-base-beans).
+
+```python
+from huggingface_hub import ModelCard
+
+card = ModelCard.load('nateraw/vit-base-beans')
+```
+
+This card has some helpful attributes that you may want to access and leverage:
+ - `card.data`: Returns a [`ModelCardData`] instance with the model card's metadata. Call `.to_dict()` on this instance to get the representation as a dictionary.
+ - `card.text`: Returns the text of the card, *excluding the metadata header*.
+ - `card.content`: Returns the text content of the card, *including the metadata header*.
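+
+For example, you might inspect these attributes like so (a minimal sketch, reusing the `card` loaded above):
+
+```python
+print(card.data.to_dict())  # metadata as a dictionary
+print(card.text[:100])      # card body, excluding the metadata header
+print(card.content[:100])   # full card content, including the metadata header
+```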
+
+## Creating Model Cards
+
+### From Text
+
+To initialize a Model Card from text, just pass the card's text content to `ModelCard` on init.
+
+```python
+content = """
+---
+language: en
+license: mit
+---
+
+# My Model Card
+"""
+
+card = ModelCard(content)
+card.data.to_dict() == {'language': 'en', 'license': 'mit'} # True
+```
+
+Another way you might want to do this is with f-strings. In the following example, we:
+
+- Use [`ModelCardData.to_yaml`] to convert the metadata we defined to YAML so we can insert it as the YAML block of the model card.
+- Show how you might use a template variable via Python f-strings.
+
+```python
+from huggingface_hub import ModelCard, ModelCardData
+
+card_data = ModelCardData(language='en', license='mit', library_name='timm')
+
+example_template_var = 'nateraw'
+content = f"""
+---
+{ card_data.to_yaml() }
+---
+
+# My Model Card
+
+This model was created by [@{example_template_var}](https://github.com/{example_template_var})
+"""
+
+card = ModelCard(content)
+print(card)
+```
+
+The above example would leave us with a card that looks like this:
+
+```
+---
+language: en
+license: mit
+library_name: timm
+---
+
+# My Model Card
+
+This model was created by [@nateraw](https://github.com/nateraw)
+```
+
+### From a Jinja Template
+
+If you have `Jinja2` installed, you can create Model Cards from a Jinja template file. Let's see a basic example:
+
+```python
+from pathlib import Path
+
+from huggingface_hub import ModelCard, ModelCardData
+
+# Define your jinja template
+template_text = """
+---
+{{ card_data }}
+---
+
+# Model Card for MyCoolModel
+
+This model does this and that.
+
+This model was created by [@{{ author }}](https://hf.co/{{ author }}).
+""".strip()
+
+# Write the template to a file
+Path('custom_template.md').write_text(template_text)
+
+# Define card metadata
+card_data = ModelCardData(language='en', license='mit', library_name='keras')
+
+# Create card from template, passing it any jinja template variables you want.
+# In our case, we'll pass author
+card = ModelCard.from_template(card_data, template_path='custom_template.md', author='nateraw')
+card.save('my_model_card_1.md')
+print(card)
+```
+
+The resulting card's markdown looks like this:
+
+```
+---
+language: en
+license: mit
+library_name: keras
+---
+
+# Model Card for MyCoolModel
+
+This model does this and that.
+
+This model was created by [@nateraw](https://hf.co/nateraw).
+```
+
+If you update any of the `card.data` attributes, the change will be reflected in the card itself.
+
+```python
+card.data.library_name = 'timm'
+card.data.language = 'fr'
+card.data.license = 'apache-2.0'
+print(card)
+```
+
+Now, as you can see, the metadata header has been updated:
+
+```
+---
+language: fr
+license: apache-2.0
+library_name: timm
+---
+
+# Model Card for MyCoolModel
+
+This model does this and that.
+
+This model was created by [@nateraw](https://hf.co/nateraw).
+```
+
+As you update the card data, you can make sure the card is still valid against the Hub by calling [`ModelCard.validate`]. This ensures that the card passes any validation rules set up on the Hugging Face Hub.
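+
+For example, reusing the `card` from above (a minimal sketch; validation makes a request to the Hub, so it requires internet access):
+
+```python
+card.validate()  # raises a ValueError if the card fails the Hub's validation rules
+```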
+
+### From the Default Template
+
+Instead of using your own template, you can also use the [default template](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md), which is a fully featured model card with tons of sections you may want to fill out. Under the hood, it uses [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/) to fill out a template file.
+
+<Tip>
+
+Note that you must have Jinja2 installed to use `from_template`. You can install it with `pip install Jinja2`.
+
+</Tip>
+
+```python
+card_data = ModelCardData(language='en', license='mit', library_name='keras')
+card = ModelCard.from_template(
+ card_data,
+ model_id='my-cool-model',
+ model_description="this model does this and that",
+ developers="Nate Raw",
+ more_resources="https://github.com/huggingface/huggingface_hub",
+)
+card.save('my_model_card_2.md')
+print(card)
+```
+
+## Sharing Model Cards
+
+If you're authenticated with the Hugging Face Hub (either by using `huggingface-cli login` or `huggingface_hub.notebook_login()`), you can push cards to the Hub by calling [`ModelCard.push_to_hub`]. Let's take a look at how to do that.
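+
+For example, in a notebook you could authenticate like this (a minimal sketch):
+
+```python
+from huggingface_hub import notebook_login
+
+notebook_login()  # prompts for a token from https://huggingface.co/settings/tokens
+```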
+
+First, we'll create a new repo called 'hf-hub-modelcards-pr-test' under the authenticated user's namespace:
+
+```python
+from huggingface_hub import whoami, create_repo
+
+user = whoami()['name']
+repo_id = f'{user}/hf-hub-modelcards-pr-test'
+url = create_repo(repo_id, exist_ok=True)
+```
+
+Then, we'll create a card from the default template (same as the one defined in the section above):
+
+```python
+card_data = ModelCardData(language='en', license='mit', library_name='keras')
+card = ModelCard.from_template(
+ card_data,
+ model_id='my-cool-model',
+ model_description="this model does this and that",
+ developers="Nate Raw",
+ more_resources="https://github.com/huggingface/huggingface_hub",
+)
+```
+
+Finally, we'll push the card up to the Hub:
+
+```python
+card.push_to_hub(repo_id)
+```
+
+You can check out the resulting card [here](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/blob/main/README.md).
+
+If you instead want to push a card as a pull request, you can pass `create_pr=True` when calling `push_to_hub`:
+
+```python
+card.push_to_hub(repo_id, create_pr=True)
+```
+
+A resulting PR created from this command can be seen [here](https://huggingface.co/nateraw/hf-hub-modelcards-pr-test/discussions/3).
+
+### Including Evaluation Results
+
+To include evaluation results in the metadata `model-index`, you can pass an [`EvalResult`] or a list of `EvalResult` with your associated evaluation results. Under the hood it'll create the `model-index` when you call `card.data.to_dict()`. For more information on how this works, you can check out [this section of the Hub docs](https://huggingface.co/docs/hub/models-cards#evaluation-results).
+
+<Tip>
+
+Note that using this function requires you to include the `model_name` attribute in [`ModelCardData`].
+
+</Tip>
+
+```python
+from huggingface_hub import EvalResult, ModelCardData
+
+card_data = ModelCardData(
+ language='en',
+ license='mit',
+ model_name='my-cool-model',
+    eval_results=EvalResult(
+ task_type='image-classification',
+ dataset_type='beans',
+ dataset_name='Beans',
+ metric_type='accuracy',
+ metric_value=0.7
+ )
+)
+
+card = ModelCard.from_template(card_data)
+print(card.data)
+```
+
+The resulting `card.data` should look like this:
+
+```
+language: en
+license: mit
+model-index:
+- name: my-cool-model
+ results:
+ - task:
+ type: image-classification
+ dataset:
+ name: Beans
+ type: beans
+ metrics:
+ - type: accuracy
+ value: 0.7
+```
+
+If you have more than one evaluation result you'd like to share, just pass a list of `EvalResult`:
+
+```python
+card_data = ModelCardData(
+ language='en',
+ license='mit',
+ model_name='my-cool-model',
+    eval_results=[
+ EvalResult(
+ task_type='image-classification',
+ dataset_type='beans',
+ dataset_name='Beans',
+ metric_type='accuracy',
+ metric_value=0.7
+ ),
+ EvalResult(
+ task_type='image-classification',
+ dataset_type='beans',
+ dataset_name='Beans',
+ metric_type='f1',
+ metric_value=0.65
+ )
+ ]
+)
+card = ModelCard.from_template(card_data)
+print(card.data)
+```
+
+This should leave you with the following `card.data`:
+
+```
+language: en
+license: mit
+model-index:
+- name: my-cool-model
+ results:
+ - task:
+ type: image-classification
+ dataset:
+ name: Beans
+ type: beans
+ metrics:
+ - type: accuracy
+ value: 0.7
+ - type: f1
+ value: 0.65
+```
\ No newline at end of file
diff --git a/docs/source/package_reference/cards.mdx b/docs/source/package_reference/cards.mdx
new file mode 100644
index 0000000000..851777f5af
--- /dev/null
+++ b/docs/source/package_reference/cards.mdx
@@ -0,0 +1,62 @@
+# Repository Cards
+
+The `huggingface_hub` library provides a Python interface to create, share, and update Model/Dataset Cards.
+Visit the [dedicated documentation page](https://huggingface.co/docs/hub/models-cards) for a deeper view of what
+Model Cards on the Hub are, and how they work under the hood. You can also check out our [Model Cards guide](../how-to-model-cards) to
+get a feel for how you would use these utilities in your own projects.
+
+## Repo Card
+
+The `RepoCard` object is the parent class of [`ModelCard`] and [`DatasetCard`].
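+
+For example, both subclasses inherit `load`, `save`, `validate`, and `push_to_hub` from `RepoCard`. A minimal sketch, assuming both repos exist on the Hub:
+
+```python
+from huggingface_hub import DatasetCard, ModelCard
+
+# ModelCard and DatasetCard share the interface defined on RepoCard
+model_card = ModelCard.load('nateraw/vit-base-beans')
+dataset_card = DatasetCard.load('beans')
+```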
+
+[[autodoc]] huggingface_hub.repocard.RepoCard
+ - __init__
+ - all
+
+## Card Data
+
+The [`CardData`] object is the parent class of [`ModelCardData`] and [`DatasetCardData`].
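+
+For example, a quick sketch of the shared `to_dict`/`to_yaml` interface, shown here via [`ModelCardData`]:
+
+```python
+from huggingface_hub import ModelCardData
+
+data = ModelCardData(language='en', license='mit')
+data.to_dict()    # {'language': 'en', 'license': 'mit'}
+print(data.to_yaml())
+```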
+
+[[autodoc]] huggingface_hub.repocard_data.CardData
+
+## Model Cards
+### ModelCard
+
+[[autodoc]] ModelCard
+
+### ModelCardData
+
+[[autodoc]] ModelCardData
+
+## Dataset Cards
+
+Dataset cards are also known as Data Cards in the ML community.
+
+### DatasetCard
+
+[[autodoc]] DatasetCard
+
+### DatasetCardData
+
+[[autodoc]] DatasetCardData
+
+## Utilities
+
+### EvalResult
+
+[[autodoc]] EvalResult
+
+### model_index_to_eval_results
+
+[[autodoc]] huggingface_hub.repocard_data.model_index_to_eval_results
+
+### eval_results_to_model_index
+
+[[autodoc]] huggingface_hub.repocard_data.eval_results_to_model_index
+
+### metadata_eval_result
+
+[[autodoc]] huggingface_hub.repocard.metadata_eval_result
+
+### metadata_update
+
+[[autodoc]] huggingface_hub.repocard.metadata_update
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 0245e5113b..e310583370 100644
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@ def get_version() -> str:
"pytest-cov",
"datasets",
"soundfile",
+ "Jinja2",
]
extras["quality"] = [
@@ -90,4 +91,5 @@ def get_version() -> str:
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
+ include_package_data=True,
)
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 26ce55c74d..55b9ff72a5 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -214,6 +214,8 @@ def __dir__():
"metadata_load",
"metadata_save",
"metadata_update",
+ "ModelCard",
+ "DatasetCard",
],
"community": [
"Discussion",
@@ -224,5 +226,6 @@ def __dir__():
"DiscussionCommit",
"DiscussionTitleChange",
],
+ "repocard_data": ["CardData", "ModelCardData", "DatasetCardData", "EvalResult"],
},
)
diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py
index f6eaf0f802..c5cf2c386b 100644
--- a/src/huggingface_hub/file_download.py
+++ b/src/huggingface_hub/file_download.py
@@ -122,6 +122,14 @@ def is_graphviz_available():
except importlib_metadata.PackageNotFoundError:
pass
+_jinja_version = "N/A"
+_jinja_available = False
+try:
+ _jinja_version: str = importlib_metadata.version("Jinja2")
+ _jinja_available = True
+except importlib_metadata.PackageNotFoundError:
+ pass
+
def is_torch_available():
return _torch_available
@@ -151,6 +159,14 @@ def get_fastcore_version():
return _fastcore_version
+def is_jinja_available():
+ return _jinja_available
+
+
+def get_jinja_version():
+ return _jinja_version
+
+
REGEX_COMMIT_HASH = re.compile(r"^[0-9a-f]{40}$")
diff --git a/src/huggingface_hub/repocard.py b/src/huggingface_hub/repocard.py
index 49ac41d4a4..70b6adb750 100644
--- a/src/huggingface_hub/repocard.py
+++ b/src/huggingface_hub/repocard.py
@@ -1,30 +1,481 @@
-import dataclasses
import os
import re
-import shutil
+import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Union
+
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
+import requests
import yaml
-from huggingface_hub.file_download import hf_hub_download
-from huggingface_hub.hf_api import HfApi
-from huggingface_hub.repocard_types import (
- ModelIndex,
- SingleMetric,
- SingleResult,
- SingleResultDataset,
- SingleResultTask,
+from huggingface_hub.file_download import hf_hub_download, is_jinja_available
+from huggingface_hub.hf_api import upload_file
+from huggingface_hub.repocard_data import (
+ CardData,
+ DatasetCardData,
+ EvalResult,
+ ModelCardData,
+ eval_results_to_model_index,
+ model_index_to_eval_results,
)
from .constants import REPOCARD_NAME
+from .utils.logging import get_logger
+TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
+TEMPLATE_DATASETCARD_PATH = (
+ Path(__file__).parent / "templates" / "datasetcard_template.md"
+)
+
# exact same regex as in the Hub server. Please keep in sync.
REGEX_YAML_BLOCK = re.compile(r"---[\n\r]+([\S\s]*?)[\n\r]+---[\n\r]")
-UNIQUE_RESULT_FEATURES = ["dataset", "task"]
-UNIQUE_METRIC_FEATURES = ["name", "type"]
+logger = get_logger(__name__)
+
+
+class RepoCard:
+
+ card_data_class = CardData
+ default_template_path = TEMPLATE_MODELCARD_PATH
+ repo_type = "model"
+
+ def __init__(self, content: str):
+ """Initialize a RepoCard from string content. The content should be a
+ Markdown file with a YAML block at the beginning and a Markdown body.
+
+ Args:
+ content (`str`): The content of the Markdown file.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard import RepoCard
+ >>> text = '''
+ ... ---
+ ... language: en
+ ... license: mit
+ ... ---
+ ...
+ ... # My repo
+ ... '''
+ >>> card = RepoCard(text)
+ >>> card.data.to_dict()
+ {'language': 'en', 'license': 'mit'}
+ >>> card.text
+ '\\n# My repo\\n'
+
+ ```
+
+ Raises the following error:
+
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ when the content of the repo card metadata is not a dictionary.
+
+
+ """
+ self.content = content
+ match = REGEX_YAML_BLOCK.search(content)
+ if match:
+ # Metadata found in the YAML block
+ yaml_block = match.group(1)
+ self.text = content[match.end() :]
+ data_dict = yaml.safe_load(yaml_block)
+
+ # The YAML block's data should be a dictionary
+ if not isinstance(data_dict, dict):
+ raise ValueError("repo card metadata block should be a dict")
+ else:
+ # Model card without metadata... create empty metadata
+ logger.warning(
+ "Repo card metadata block was not found. Setting CardData to empty."
+ )
+ data_dict = {}
+ self.text = content
+
+ self.data = self.card_data_class(**data_dict)
+
+ def __str__(self):
+ line_break = _detect_line_ending(self.content) or "\n"
+ return f"---{line_break}{self.data.to_yaml(line_break=line_break)}{line_break}---{line_break}{self.text}"
+
+ def save(self, filepath: Union[Path, str]):
+ r"""Save a RepoCard to a file.
+
+ Args:
+ filepath (`Union[Path, str]`): Filepath to the markdown file to save.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard import RepoCard
+ >>> card = RepoCard("---\nlanguage: en\n---\n# This is a test repo card")
+ >>> card.save("/tmp/test.md")
+
+ ```
+ """
+ filepath = Path(filepath)
+ filepath.parent.mkdir(parents=True, exist_ok=True)
+ filepath.write_text(str(self))
+
+ @classmethod
+ def load(
+ cls,
+ repo_id_or_path: Union[str, Path],
+ repo_type: Optional[str] = None,
+ token: Optional[str] = None,
+ ):
+ """Initialize a RepoCard from a Hugging Face Hub repo's README.md or a local filepath.
+
+ Args:
+ repo_id_or_path (`Union[str, Path]`):
+ The repo ID associated with a Hugging Face Hub repo or a local filepath.
+ repo_type (`str`, *optional*):
+                The type of Hugging Face repo to push to. Defaults to None, which will use
+                "model". Other options are "dataset" and "space". Not used when loading from
+ a local filepath. If this is called from a child class, the default value will be
+ the child class's `repo_type`.
+ token (`str`, *optional*):
+ Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to
+ the stored token.
+
+ Returns:
+ [`huggingface_hub.repocard.RepoCard`]: The RepoCard (or subclass) initialized from the repo's
+ README.md file or filepath.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard import RepoCard
+ >>> card = RepoCard.load("nateraw/food")
+ >>> assert card.data.tags == ["generated_from_trainer", "image-classification", "pytorch"]
+
+ ```
+ """
+
+ if Path(repo_id_or_path).exists():
+ card_path = Path(repo_id_or_path)
+ else:
+ card_path = hf_hub_download(
+ repo_id_or_path,
+ REPOCARD_NAME,
+ repo_type=repo_type or cls.repo_type,
+ use_auth_token=token,
+ )
+
+ # Preserve newlines in the existing file.
+ with Path(card_path).open(mode="r", newline="") as f:
+ return cls(f.read())
+
+ def validate(self, repo_type: Optional[str] = None):
+ """Validates card against Hugging Face Hub's card validation logic.
+ Using this function requires access to the internet, so it is only called
+ internally by [`huggingface_hub.repocard.RepoCard.push_to_hub`].
+
+ Args:
+ repo_type (`str`, *optional*, defaults to "model"):
+ The type of Hugging Face repo to push to. Options are "model", "dataset", and "space".
+ If this function is called from a child class, the default will be the child class's `repo_type`.
+
+
+ Raises the following errors:
+
+ - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
+ if the card fails validation checks.
+ - [`HTTPError`](https://2.python-requests.org/en/master/api/#requests.HTTPError)
+ if the request to the Hub API fails for any other reason.
+
+
+ """
+
+        # If repo_type is not provided, use the card's repo_type.
+ repo_type = repo_type or self.repo_type
+
+ body = {
+ "repoType": repo_type,
+ "content": str(self),
+ }
+ headers = {"Accept": "text/plain"}
+
+ try:
+ r = requests.post(
+ "https://huggingface.co/api/validate-yaml", body, headers=headers
+ )
+ r.raise_for_status()
+ except requests.exceptions.HTTPError as exc:
+ if r.status_code == 400:
+ raise ValueError(r.text)
+ else:
+ raise exc
+
+ def push_to_hub(
+ self,
+ repo_id: str,
+ token: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ commit_message: Optional[str] = None,
+ commit_description: Optional[str] = None,
+ revision: Optional[str] = None,
+ create_pr: Optional[bool] = None,
+ parent_commit: Optional[str] = None,
+ ):
+ """Push a RepoCard to a Hugging Face Hub repo.
+
+ Args:
+ repo_id (`str`):
+ The repo ID of the Hugging Face Hub repo to push to. Example: "nateraw/food".
+ token (`str`, *optional*):
+ Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to
+ the stored token.
+ repo_type (`str`, *optional*, defaults to "model"):
+ The type of Hugging Face repo to push to. Options are "model", "dataset", and "space". If this
+ function is called by a child class, it will default to the child class's `repo_type`.
+ commit_message (`str`, *optional*):
+ The summary / title / first line of the generated commit.
+            commit_description (`str`, *optional*):
+ The description of the generated commit.
+ revision (`str`, *optional*):
+ The git revision to commit from. Defaults to the head of the `"main"` branch.
+ create_pr (`bool`, *optional*):
+ Whether or not to create a Pull Request with this commit. Defaults to `False`.
+ parent_commit (`str`, *optional*):
+ The OID / SHA of the parent commit, as a hexadecimal string. Shorthands (7 first characters) are also supported.
+ If specified and `create_pr` is `False`, the commit will fail if `revision` does not point to `parent_commit`.
+ If specified and `create_pr` is `True`, the pull request will be created from `parent_commit`.
+ Specifying `parent_commit` ensures the repo has not changed before committing the changes, and can be
+ especially useful if the repo is updated / committed to concurrently.
+ Returns:
+ `str`: URL of the commit which updated the card metadata.
+ """
+
+        # If repo_type is not provided, use the card's repo_type.
+ repo_type = repo_type or self.repo_type
+
+ # Validate card before pushing to hub
+ self.validate(repo_type=repo_type)
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmp_path = Path(tmpdir) / REPOCARD_NAME
+ tmp_path.write_text(str(self))
+ url = upload_file(
+ path_or_fileobj=str(tmp_path),
+ path_in_repo=REPOCARD_NAME,
+ repo_id=repo_id,
+ token=token,
+ repo_type=repo_type,
+ commit_message=commit_message,
+ commit_description=commit_description,
+ create_pr=create_pr,
+ revision=revision,
+ parent_commit=parent_commit,
+ )
+ return url
+
+ @classmethod
+ def from_template(
+ cls,
+ card_data: CardData,
+ template_path: Optional[str] = None,
+ **template_kwargs,
+ ):
+ """Initialize a RepoCard from a template. By default, it uses the default template.
+
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
+
+ Args:
+ card_data (`huggingface_hub.CardData`):
+ A huggingface_hub.CardData instance containing the metadata you want to include in the YAML
+ header of the repo card on the Hugging Face Hub.
+ template_path (`str`, *optional*):
+ A path to a markdown file with optional Jinja template variables that can be filled
+ in with `template_kwargs`. Defaults to the default template.
+
+ Returns:
+ [`huggingface_hub.repocard.RepoCard`]: A RepoCard instance with the specified card data and content from the
+ template.
+ """
+ if is_jinja_available():
+ import jinja2
+ else:
+ raise ImportError(
+ "Using RepoCard.from_template requires Jinja2 to be installed. Please"
+ " install it with `pip install Jinja2`."
+ )
+
+ template_path = template_path or cls.default_template_path
+ kwargs = card_data.to_dict().copy()
+ kwargs.update(template_kwargs) # Template_kwargs have priority
+ content = jinja2.Template(Path(template_path).read_text()).render(
+ card_data=card_data.to_yaml(), **kwargs
+ )
+ return cls(content)
+
+
+class ModelCard(RepoCard):
+ card_data_class = ModelCardData
+ default_template_path = TEMPLATE_MODELCARD_PATH
+ repo_type = "model"
+
+ @classmethod
+ def from_template(
+ cls,
+ card_data: ModelCardData,
+ template_path: Optional[str] = None,
+ **template_kwargs,
+ ):
+ """Initialize a ModelCard from a template. By default, it uses the default template, which can be found here:
+ https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/modelcard_template.md
+
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
+
+ Args:
+ card_data (`huggingface_hub.ModelCardData`):
+ A huggingface_hub.ModelCardData instance containing the metadata you want to include in the YAML
+ header of the model card on the Hugging Face Hub.
+ template_path (`str`, *optional*):
+ A path to a markdown file with optional Jinja template variables that can be filled
+ in with `template_kwargs`. Defaults to the default template.
+
+ Returns:
+ [`huggingface_hub.ModelCard`]: A ModelCard instance with the specified card data and content from the
+ template.
+
+ Example:
+ ```python
+ >>> from huggingface_hub import ModelCard, ModelCardData, EvalResult
+
+ >>> # Using the Default Template
+ >>> card_data = ModelCardData(
+ ... language='en',
+ ... license='mit',
+ ... library_name='timm',
+ ... tags=['image-classification', 'resnet'],
+ ... datasets='beans',
+ ... metrics=['accuracy'],
+ ... )
+ >>> card = ModelCard.from_template(
+ ... card_data,
+ ... model_description='This model does x + y...'
+ ... )
+
+ >>> # Including Evaluation Results
+ >>> card_data = ModelCardData(
+ ... language='en',
+ ... tags=['image-classification', 'resnet'],
+ ... eval_results=[
+ ... EvalResult(
+ ... task_type='image-classification',
+ ... dataset_type='beans',
+ ... dataset_name='Beans',
+ ... metric_type='accuracy',
+ ... metric_value=0.9,
+ ... ),
+ ... ],
+ ... model_name='my-cool-model',
+ ... )
+ >>> card = ModelCard.from_template(card_data)
+
+ >>> # Using a Custom Template
+ >>> card_data = ModelCardData(
+ ... language='en',
+ ... tags=['image-classification', 'resnet']
+ ... )
+ >>> card = ModelCard.from_template(
+ ... card_data=card_data,
+ ... template_path='./src/huggingface_hub/templates/modelcard_template.md',
+ ... custom_template_var='custom value', # will be replaced in template if it exists
+ ... )
+
+ ```
+ """
+ return super().from_template(card_data, template_path, **template_kwargs)
+
+
+class DatasetCard(RepoCard):
+ card_data_class = DatasetCardData
+ default_template_path = TEMPLATE_DATASETCARD_PATH
+ repo_type = "dataset"
+
+ @classmethod
+ def from_template(
+ cls,
+ card_data: DatasetCardData,
+ template_path: Optional[str] = None,
+ **template_kwargs,
+ ):
+ """Initialize a DatasetCard from a template. By default, it uses the default template, which can be found here:
+ https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/templates/datasetcard_template.md
+
+ Templates are Jinja2 templates that can be customized by passing keyword arguments.
+
+ Args:
+ card_data (`huggingface_hub.DatasetCardData`):
+ A huggingface_hub.DatasetCardData instance containing the metadata you want to include in the YAML
+ header of the dataset card on the Hugging Face Hub.
+ template_path (`str`, *optional*):
+ A path to a markdown file with optional Jinja template variables that can be filled
+ in with `template_kwargs`. Defaults to the default template.
+
+ Returns:
+ [`huggingface_hub.DatasetCard`]: A DatasetCard instance with the specified card data and content from the
+ template.
+
+ Example:
+ ```python
+ >>> from huggingface_hub import DatasetCard, DatasetCardData
+
+ >>> # Using the Default Template
+ >>> card_data = DatasetCardData(
+ ... language='en',
+ ... license='mit',
+ ... annotations_creators='crowdsourced',
+ ... task_categories=['text-classification'],
+ ... task_ids=['sentiment-classification', 'text-scoring'],
+ ... multilinguality='monolingual',
+ ... pretty_name='My Text Classification Dataset',
+ ... )
+ >>> card = DatasetCard.from_template(
+ ... card_data,
+ ... pretty_name=card_data.pretty_name,
+ ... )
+
+ >>> # Using a Custom Template
+ >>> card_data = DatasetCardData(
+ ... language='en',
+ ... license='mit',
+ ... )
+ >>> card = DatasetCard.from_template(
+ ... card_data=card_data,
+ ... template_path='./src/huggingface_hub/templates/datasetcard_template.md',
+ ... custom_template_var='custom value', # will be replaced in template if it exists
+ ... )
+
+ ```
+ """
+ return super().from_template(card_data, template_path, **template_kwargs)
+
+
+def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]:
+ """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines.
+
+    Uses the same implementation as the Hub server. Please keep it in sync.
+
+    Returns:
+        str: The detected line ending of the string, or None if the string contains no line breaks.
+ """
+ cr = content.count("\r")
+ lf = content.count("\n")
+ crlf = content.count("\r\n")
+ if cr + lf == 0:
+ return None
+ if crlf == cr and crlf == lf:
+ return "\r\n"
+ if cr > lf:
+ return "\r"
+ else:
+ return "\n"
def metadata_load(local_path: Union[str, Path]) -> Optional[Dict]:
@@ -131,78 +582,75 @@ def metadata_eval_result(
`dict`: a metadata dict with the result from a model evaluated on a dataset.
Example:
- >>> from huggingface_hub import metadata_eval_result
- >>> metadata_eval_result(
- ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF",
- ... task_pretty_name="Text Classification",
- ... task_id="text-classification",
- ... metrics_pretty_name="Accuracy",
- ... metrics_id="accuracy",
- ... metrics_value=0.2662102282047272,
- ... dataset_pretty_name="ReactionJPEG",
- ... dataset_id="julien-c/reactionjpeg",
- ... dataset_config="default",
- ... dataset_split="test",
- ... )
- {
- "model-index": [
- {
- "name": "RoBERTa fine-tuned on ReactionGIF",
- "results": [
- {
- "task": {
- "type": "text-classification",
- "name": "Text Classification",
- },
- "dataset": {
- "name": "ReactionJPEG",
- "type": "julien-c/reactionjpeg",
- "config": "default",
- "split": "test",
- },
- "metrics": [
- {
- "type": "accuracy",
- "value": 0.2662102282047272,
- "name": "Accuracy",
- "verified": False,
- }
- ],
- }
- ],
- }
- ]
- }
+ ```python
+ >>> from huggingface_hub import metadata_eval_result
+ >>> results = metadata_eval_result(
+ ... model_pretty_name="RoBERTa fine-tuned on ReactionGIF",
+ ... task_pretty_name="Text Classification",
+ ... task_id="text-classification",
+ ... metrics_pretty_name="Accuracy",
+ ... metrics_id="accuracy",
+ ... metrics_value=0.2662102282047272,
+ ... dataset_pretty_name="ReactionJPEG",
+ ... dataset_id="julien-c/reactionjpeg",
+ ... dataset_config="default",
+ ... dataset_split="test",
+ ... )
+ >>> results == {
+ ... 'model-index': [
+ ... {
+ ... 'name': 'RoBERTa fine-tuned on ReactionGIF',
+ ... 'results': [
+ ... {
+ ... 'task': {
+ ... 'type': 'text-classification',
+ ... 'name': 'Text Classification'
+ ... },
+ ... 'dataset': {
+ ... 'name': 'ReactionJPEG',
+ ... 'type': 'julien-c/reactionjpeg',
+ ... 'config': 'default',
+ ... 'split': 'test'
+ ... },
+ ... 'metrics': [
+ ... {
+ ... 'type': 'accuracy',
+ ... 'value': 0.2662102282047272,
+ ... 'name': 'Accuracy',
+ ... 'verified': False
+ ... }
+ ... ]
+ ... }
+ ... ]
+ ... }
+ ... ]
+ ... }
+ True
+
+ ```
"""
- model_index = ModelIndex(
- name=model_pretty_name,
- results=[
- SingleResult(
- metrics=[
- SingleMetric(
- type=metrics_id,
- name=metrics_pretty_name,
- value=metrics_value,
- config=metrics_config,
- verified=metrics_verified,
- ),
- ],
- task=SingleResultTask(type=task_id, name=task_pretty_name),
- dataset=SingleResultDataset(
- name=dataset_pretty_name,
- type=dataset_id,
- config=dataset_config,
- split=dataset_split,
- revision=dataset_revision,
- ),
- )
- ],
- )
- # use `dict_factory` to recursively ignore None values
- data = dataclasses.asdict(
- model_index, dict_factory=lambda x: {k: v for (k, v) in x if v is not None}
- )
- return {"model-index": [data]}
+
+ return {
+ "model-index": eval_results_to_model_index(
+ model_name=model_pretty_name,
+ eval_results=[
+ EvalResult(
+ task_name=task_pretty_name,
+ task_type=task_id,
+ metric_name=metrics_pretty_name,
+ metric_type=metrics_id,
+ metric_value=metrics_value,
+ dataset_name=dataset_pretty_name,
+ dataset_type=dataset_id,
+ metric_config=metrics_config,
+ verified=metrics_verified,
+ dataset_config=dataset_config,
+ dataset_split=dataset_split,
+ dataset_revision=dataset_revision,
+ )
+ ],
+ )
+ }
def metadata_update(
@@ -221,18 +669,6 @@ def metadata_update(
"""
Updates the metadata in the README.md of a repository on the Hugging Face Hub.
- Example:
- >>> from huggingface_hub import metadata_update
- >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
- ... 'results': [{'dataset': {'name': 'ReactionGIF',
- ... 'type': 'julien-c/reactiongif'},
- ... 'metrics': [{'name': 'Recall',
- ... 'type': 'recall',
- ... 'value': 0.7762102282047272}],
- ... 'task': {'name': 'Text Classification',
- ... 'type': 'text-classification'}}]}]}
- >>> update_metdata("julien-c/reactiongif-roberta", metadata)
-
Args:
repo_id (`str`):
The name of the repository.
@@ -265,6 +701,21 @@ def metadata_update(
especially useful if the repo is updated / committed to concurrently.
Returns:
`str`: URL of the commit which updated the card metadata.
+
+ Example:
+ ```python
+ >>> from huggingface_hub import metadata_update
+ >>> metadata = {'model-index': [{'name': 'RoBERTa fine-tuned on ReactionGIF',
+ ... 'results': [{'dataset': {'name': 'ReactionGIF',
+ ... 'type': 'julien-c/reactiongif'},
+ ... 'metrics': [{'name': 'Recall',
+ ... 'type': 'recall',
+ ... 'value': 0.7762102282047272}],
+ ... 'task': {'name': 'Text Classification',
+ ... 'type': 'text-classification'}}]}]}
+ >>> url = metadata_update("hf-internal-testing/reactiongif-roberta-card", metadata)
+
+ ```
"""
commit_message = (
commit_message
@@ -272,142 +723,69 @@ def metadata_update(
else "Update metadata with huggingface_hub"
)
- upstream_filepath = hf_hub_download(
- repo_id,
- filename=REPOCARD_NAME,
- repo_type=repo_type,
- use_auth_token=token,
- )
- # work on a copy of the upstream file, to not mess up the cache
- with tempfile.TemporaryDirectory() as tmpdirname:
- filepath = shutil.copy(upstream_filepath, tmpdirname)
-
- existing_metadata = metadata_load(filepath)
-
- for key in metadata:
- # update model index containing the evaluation results
- if key == "model-index":
- if "model-index" not in existing_metadata:
- existing_metadata["model-index"] = metadata["model-index"]
- else:
- # the model-index contains a list of results as used by PwC but only has one element thus we take the first one
- existing_metadata["model-index"][0][
- "results"
- ] = _update_metadata_model_index(
- existing_metadata["model-index"][0]["results"],
- metadata["model-index"][0]["results"],
- overwrite=overwrite,
- )
- # update all fields except model index
- else:
- if key in existing_metadata and not overwrite:
- if existing_metadata[key] != metadata[key]:
- raise ValueError(
- f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
- )
- else:
- existing_metadata[key] = metadata[key]
-
- # save and push to hub
- metadata_save(filepath, existing_metadata)
-
- return HfApi().upload_file(
- path_or_fileobj=filepath,
- path_in_repo=REPOCARD_NAME,
- repo_id=repo_id,
- repo_type=repo_type,
- token=token,
- commit_message=commit_message,
- commit_description=commit_description,
- create_pr=create_pr,
- revision=revision,
- parent_commit=parent_commit,
- )
-
+ card = ModelCard.load(repo_id, token=token)
-def _update_metadata_model_index(existing_results, new_results, overwrite=False):
- """
- Updates the model-index fields in the metadata. If results with same unique
- features exist they are updated, else a new result is appended. Updating existing
- values is only possible if `overwrite=True`.
-
- Args:
- new_metrics (`List[dict]`):
- List of new metadata results.
- existing_metrics (`List[dict]`):
- List of existing metadata results.
- overwrite (`bool`, *optional*, defaults to `False`):
- If set to `True`, an existing metric values can be overwritten, otherwise
- attempting to overwrite an existing field will cause an error.
+ for key, value in metadata.items():
+ if key == "model-index":
+ model_name, new_results = model_index_to_eval_results(value)
+ if card.data.eval_results is None:
+ card.data.eval_results = new_results
+ card.data.model_name = model_name
+ else:
+ existing_results = card.data.eval_results
- Returns:
- `list`: List of updated metadata results
- """
- for new_result in new_results:
- result_found = False
- for existing_result_index, existing_result in enumerate(existing_results):
- if all(
- new_result[feat] == existing_result[feat]
- for feat in UNIQUE_RESULT_FEATURES
+ for new_result in new_results:
+ result_found = False
+ for existing_result_index, existing_result in enumerate(
+ existing_results
+ ):
+ if all(
+ [
+ new_result.dataset_name == existing_result.dataset_name,
+ new_result.dataset_type == existing_result.dataset_type,
+ new_result.task_type == existing_result.task_type,
+ new_result.task_name == existing_result.task_name,
+ new_result.metric_name == existing_result.metric_name,
+ new_result.metric_type == existing_result.metric_type,
+ ]
+ ):
+ if (
+ new_result.metric_value != existing_result.metric_value
+ and not overwrite
+ ):
+ existing_str = (
+ f"name: {new_result.metric_name}, type:"
+ f" {new_result.metric_type}"
+ )
+ raise ValueError(
+ "You passed a new value for the existing metric"
+ f" '{existing_str}'. Set `overwrite=True` to"
+ " overwrite existing metrics."
+ )
+ result_found = True
+ card.data.eval_results[existing_result_index] = new_result
+ if not result_found:
+ card.data.eval_results.append(new_result)
+ else:
+ if (
+ hasattr(card.data, key)
+ and getattr(card.data, key) is not None
+ and not overwrite
+ and getattr(card.data, key) != value
):
- result_found = True
- existing_results[existing_result_index][
- "metrics"
- ] = _update_metadata_results_metric(
- new_result["metrics"],
- existing_result["metrics"],
- overwrite=overwrite,
+ raise ValueError(
+ f"""You passed a new value for the existing meta data field '{key}'. Set `overwrite=True` to overwrite existing metadata."""
)
- if not result_found:
- existing_results.append(new_result)
- return existing_results
-
-
-def _update_metadata_results_metric(new_metrics, existing_metrics, overwrite=False):
- """
- Updates the metrics list of a result in the metadata. If metrics with same unique
- features exist their values are updated, else a new metric is appended. Updating
- existing values is only possible if `overwrite=True`.
-
- Args:
- new_metrics (`list`):
- List of new metrics.
- existing_metrics (`list`):
- List of existing metrics.
- overwrite (`bool`, *optional*, defaults to `False`):
- If set to `True`, an existing metric values can be overwritten, otherwise
- attempting to overwrite an existing field will cause an error.
+ else:
+ setattr(card.data, key, value)
- Returns:
- `list`: List of updated metrics
- """
- for new_metric in new_metrics:
- metric_exists = False
- for existing_metric_index, existing_metric in enumerate(existing_metrics):
- if all(
- new_metric[feat] == existing_metric[feat]
- for feat in UNIQUE_METRIC_FEATURES
- ):
- if overwrite:
- existing_metrics[existing_metric_index]["value"] = new_metric[
- "value"
- ]
- else:
- # if metric exists and value is not the same throw an error without overwrite flag
- if (
- existing_metrics[existing_metric_index]["value"]
- != new_metric["value"]
- ):
- existing_str = ", ".join(
- f"{feat}: {new_metric[feat]}"
- for feat in UNIQUE_METRIC_FEATURES
- )
- raise ValueError(
- "You passed a new value for the existing metric"
- f" '{existing_str}'. Set `overwrite=True` to overwrite"
- " existing metrics."
- )
- metric_exists = True
- if not metric_exists:
- existing_metrics.append(new_metric)
- return existing_metrics
+ return card.push_to_hub(
+ repo_id,
+ token=token,
+ repo_type=repo_type,
+ commit_message=commit_message,
+ commit_description=commit_description,
+ create_pr=create_pr,
+ revision=revision,
+ parent_commit=parent_commit,
+ )
diff --git a/src/huggingface_hub/repocard_data.py b/src/huggingface_hub/repocard_data.py
new file mode 100644
index 0000000000..0daac082c4
--- /dev/null
+++ b/src/huggingface_hub/repocard_data.py
@@ -0,0 +1,538 @@
+import copy
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import yaml
+
+from .utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+@dataclass
+class EvalResult:
+ """
+ Flattened representation of individual evaluation results found in model-index of Model Cards.
+
+ For more information on the model-index spec, see https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1.
+
+ Args:
+ task_type (`str`):
+ The task identifier. Example: "image-classification".
+ dataset_type (`str`):
+ The dataset identifier. Example: "common_voice". Use dataset id from https://hf.co/datasets.
+ dataset_name (`str`):
+ A pretty name for the dataset. Example: "Common Voice (French)".
+ metric_type (`str`):
+ The metric identifier. Example: "wer". Use metric id from https://hf.co/metrics.
+ metric_value (`Any`):
+ The metric value. Example: 0.9 or "20.0 ± 1.2".
+ task_name (`str`, *optional*):
+ A pretty name for the task. Example: "Speech Recognition".
+ dataset_config (`str`, *optional*):
+ The name of the dataset configuration used in `load_dataset()`.
+ Example: fr in `load_dataset("common_voice", "fr")`. See the `datasets` docs for more info:
+ https://hf.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
+ dataset_split (`str`, *optional*):
+ The split used in `load_dataset()`. Example: "test".
+ dataset_revision (`str`, *optional*):
+ The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
+ Example: 5503434ddd753f426f4b38109466949a1217c2bb
+ dataset_args (`Dict[str, Any]`, *optional*):
+ The arguments passed during `Metric.compute()`. Example for `bleu`: `{"max_order": 4}`
+ metric_name (`str`, *optional*):
+ A pretty name for the metric. Example: "Test WER".
+ metric_config (`str`, *optional*):
+ The name of the metric configuration used in `load_metric()`.
+ Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
+ See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
+ metric_args (`Dict[str, Any]`, *optional*):
+ The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4
+ verified (`bool`, *optional*):
+ If true, indicates that evaluation was generated by Hugging Face (vs. self-reported).
+ """
+
+ # Required
+
+ # The task identifier
+ # Example: automatic-speech-recognition
+ task_type: str
+
+ # The dataset identifier
+ # Example: common_voice. Use dataset id from https://hf.co/datasets
+ dataset_type: str
+
+ # A pretty name for the dataset.
+ # Example: Common Voice (French)
+ dataset_name: str
+
+ # The metric identifier
+ # Example: wer. Use metric id from https://hf.co/metrics
+ metric_type: str
+
+ # Value of the metric.
+ # Example: 20.0 or "20.0 ± 1.2"
+ metric_value: Any
+
+ # Optional
+
+ # A pretty name for the task.
+ # Example: Speech Recognition
+ task_name: Optional[str] = None
+
+ # The name of the dataset configuration used in `load_dataset()`.
+ # Example: fr in `load_dataset("common_voice", "fr")`.
+ # See the `datasets` docs for more info:
+ # https://huggingface.co/docs/datasets/package_reference/loading_methods#datasets.load_dataset.name
+ dataset_config: Optional[str] = None
+
+ # The split used in `load_dataset()`.
+ # Example: test
+ dataset_split: Optional[str] = None
+
+ # The revision (AKA Git Sha) of the dataset used in `load_dataset()`.
+ # Example: 5503434ddd753f426f4b38109466949a1217c2bb
+ dataset_revision: Optional[str] = None
+
+ # The arguments passed during `Metric.compute()`.
+ # Example for `bleu`: max_order: 4
+ dataset_args: Optional[Dict[str, Any]] = None
+
+ # A pretty name for the metric.
+ # Example: Test WER
+ metric_name: Optional[str] = None
+
+ # The name of the metric configuration used in `load_metric()`.
+ # Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`.
+ # See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
+ metric_config: Optional[str] = None
+
+ # The arguments passed during `Metric.compute()`.
+ # Example for `bleu`: max_order: 4
+ metric_args: Optional[Dict[str, Any]] = None
+
+ # If true, indicates that evaluation was generated by Hugging Face (vs. self-reported).
+ verified: Optional[bool] = None
+
+
+@dataclass
+class CardData:
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Converts CardData to a dict.
+
+ Returns:
+ `dict`: CardData represented as a dictionary ready to be dumped to a YAML
+ block for inclusion in a README.md file.
+ """
+
+ data_dict = copy.deepcopy(self.__dict__)
+ self._to_dict(data_dict)
+ return _remove_none(data_dict)
+
+ def _to_dict(self, data_dict):
+ """Use this method in child classes to alter the dict representation of the data. Alter the dict in-place.
+
+ Args:
+ data_dict (`dict`): The raw dict representation of the card data.
+ """
+ pass
+
+ def to_yaml(self, line_break=None) -> str:
+ """Dumps CardData to a YAML block for inclusion in a README.md file.
+
+ Args:
+ line_break (str, *optional*):
+ The line break to use when dumping to yaml.
+
+ Returns:
+ `str`: CardData represented as a YAML block.
+ """
+ return yaml.dump(self.to_dict(), sort_keys=False, line_break=line_break).strip()
+
+ def __repr__(self):
+ return self.to_yaml()
+
+
+class ModelCardData(CardData):
+ """Model Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
+
+ Args:
+ language (`Union[str, List[str]]`, *optional*):
+ Language of model's training data or metadata. It must be an ISO 639-1, 639-2 or
+ 639-3 code (two/three letters), or a special value like "code", "multilingual". Defaults to `None`.
+ license (`str`, *optional*):
+ License of this model. Example: apache-2.0 or any license from
+ https://huggingface.co/docs/hub/repositories-licenses. Defaults to None.
+ library_name (`str`, *optional*):
+ Name of library used by this model. Example: keras or any library from
+ https://github.com/huggingface/hub-docs/blob/main/js/src/lib/interfaces/Libraries.ts.
+ Defaults to None.
+ tags (`List[str]`, *optional*):
+ List of tags to add to your model that can be used when filtering on the Hugging
+ Face Hub. Defaults to None.
+ datasets (`Union[str, List[str]]`, *optional*):
+ Dataset or list of datasets that were used to train this model. Should be a dataset ID
+ found on https://hf.co/datasets. Defaults to None.
+ metrics (`Union[str, List[str]]`, *optional*):
+ List of metrics used to evaluate this model. Should be a metric name that can be found
+ at https://hf.co/metrics. Example: 'accuracy'. Defaults to None.
+ eval_results (`Union[List[EvalResult], EvalResult]`, *optional*):
+ List of `huggingface_hub.EvalResult` that define evaluation results of the model. If provided,
+ `model_name` kwarg must be provided. Defaults to `None`.
+ model_name (`str`, *optional*):
+ A name for this model. Required if you provide `eval_results`. It is used along with
+ `eval_results` to construct the `model-index` within the card's metadata. The name
+ you supply here is what will be used on PapersWithCode's leaderboards. Defaults to None.
+ kwargs (`dict`, *optional*):
+ Additional metadata that will be added to the model card. Defaults to None.
+
+ Example:
+ ```python
+ >>> from huggingface_hub import ModelCardData
+ >>> card_data = ModelCardData(
+ ... language="en",
+ ... license="mit",
+ ... library_name="timm",
+ ... tags=['image-classification', 'resnet'],
+ ... )
+ >>> card_data.to_dict()
+ {'language': 'en', 'license': 'mit', 'library_name': 'timm', 'tags': ['image-classification', 'resnet']}
+
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ language: Optional[Union[str, List[str]]] = None,
+ license: Optional[str] = None,
+ library_name: Optional[str] = None,
+ tags: Optional[List[str]] = None,
+ datasets: Optional[Union[str, List[str]]] = None,
+ metrics: Optional[Union[str, List[str]]] = None,
+ eval_results: Optional[List[EvalResult]] = None,
+ model_name: Optional[str] = None,
+ **kwargs,
+ ):
+ self.language = language
+ self.license = license
+ self.library_name = library_name
+ self.tags = tags
+ self.datasets = datasets
+ self.metrics = metrics
+ self.eval_results = eval_results
+ self.model_name = model_name
+
+ model_index = kwargs.pop("model-index", None)
+ if model_index:
+ try:
+ model_name, eval_results = model_index_to_eval_results(model_index)
+ self.model_name = model_name
+ self.eval_results = eval_results
+ except KeyError:
+ logger.warning(
+ "Invalid model-index. Not loading eval results into CardData."
+ )
+
+ super().__init__(**kwargs)
+
+ if self.eval_results:
+            if isinstance(self.eval_results, EvalResult):
+ self.eval_results = [self.eval_results]
+ if self.model_name is None:
+ raise ValueError(
+ "Passing `eval_results` requires `model_name` to be set."
+ )
+
+ def _to_dict(self, data_dict):
+ """Format the internal data dict. In this case, we convert eval results to a valid model index"""
+ if self.eval_results is not None:
+ data_dict["model-index"] = eval_results_to_model_index(
+ self.model_name, self.eval_results
+ )
+ del data_dict["eval_results"], data_dict["model_name"]
+
+
+class DatasetCardData(CardData):
+ """Dataset Card Metadata that is used by Hugging Face Hub when included at the top of your README.md
+
+ Args:
+ language (`Union[str, List[str]]`, *optional*):
+ Language of dataset's data or metadata. It must be an ISO 639-1, 639-2 or
+ 639-3 code (two/three letters), or a special value like "code", "multilingual".
+ license (`Union[str, List[str]]`, *optional*):
+ License(s) of this dataset. Example: apache-2.0 or any license from
+ https://huggingface.co/docs/hub/repositories-licenses.
+ annotations_creators (`Union[str, List[str]]`, *optional*):
+ How the annotations for the dataset were created.
+ Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'no-annotation', 'other'.
+ language_creators (`Union[str, List[str]]`, *optional*):
+ How the text-based data in the dataset was created.
+ Options are: 'found', 'crowdsourced', 'expert-generated', 'machine-generated', 'other'
+ multilinguality (`Union[str, List[str]]`, *optional*):
+ Whether the dataset is multilingual.
+ Options are: 'monolingual', 'multilingual', 'translation', 'other'.
+ size_categories (`Union[str, List[str]]`, *optional*):
+            The number of examples in the dataset. Options are: 'n<1K', '1K<n<10K', '10K<n<100K',
+            '100K<n<1M', '1M<n<10M', '10M<n<100M', '100M<n<1B', '1B<n<10B', '10B<n<100B',
+            '100B<n<1T', 'n>1T', and 'other'.
+ source_datasets (`Union[str, List[str]]`, *optional*):
+ Indicates whether the dataset is an original dataset or extended from another existing dataset.
+ Options are: 'original' and 'extended'.
+ task_categories (`Union[str, List[str]]`, *optional*):
+ What categories of task does the dataset support?
+ task_ids (`Union[str, List[str]]`, *optional*):
+ What specific tasks does the dataset support?
+ paperswithcode_id (`str`, *optional*):
+ ID of the dataset on PapersWithCode.
+ pretty_name (`str`, *optional*):
+ A more human-readable name for the dataset. (ex. "Cats vs. Dogs")
+ train_eval_index (`Dict`, *optional*):
+ A dictionary that describes the necessary spec for doing evaluation on the Hub.
+ If not provided, it will be gathered from the 'train-eval-index' key of the kwargs.
+ configs (`Union[str, List[str]]`, *optional*):
+ A list of the available dataset configs for the dataset.
+ """
+
+ def __init__(
+ self,
+ *,
+ language: Optional[Union[str, List[str]]] = None,
+ license: Optional[Union[str, List[str]]] = None,
+ annotations_creators: Optional[Union[str, List[str]]] = None,
+ language_creators: Optional[Union[str, List[str]]] = None,
+ multilinguality: Optional[Union[str, List[str]]] = None,
+ size_categories: Optional[Union[str, List[str]]] = None,
+ source_datasets: Optional[Union[str, List[str]]] = None,
+ task_categories: Optional[Union[str, List[str]]] = None,
+ task_ids: Optional[Union[str, List[str]]] = None,
+ paperswithcode_id: Optional[str] = None,
+ pretty_name: Optional[str] = None,
+ train_eval_index: Optional[Dict] = None,
+ configs: Optional[Union[str, List[str]]] = None,
+ **kwargs,
+ ):
+ self.annotations_creators = annotations_creators
+ self.language_creators = language_creators
+ self.language = language
+ self.license = license
+ self.multilinguality = multilinguality
+ self.size_categories = size_categories
+ self.source_datasets = source_datasets
+ self.task_categories = task_categories
+ self.task_ids = task_ids
+ self.paperswithcode_id = paperswithcode_id
+ self.pretty_name = pretty_name
+ self.configs = configs
+
+ # TODO - maybe handle this similarly to EvalResult?
+ self.train_eval_index = train_eval_index or kwargs.pop("train-eval-index", None)
+ super().__init__(**kwargs)
+
+ def _to_dict(self, data_dict):
+ data_dict["train-eval-index"] = data_dict.pop("train_eval_index")
+
+
+def model_index_to_eval_results(
+ model_index: List[Dict[str, Any]]
+) -> Tuple[str, List[EvalResult]]:
+ """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects.
+
+ A detailed spec of the model index can be found here:
+ https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+
+ Args:
+ model_index (`List[Dict[str, Any]]`):
+ A model index data structure, likely coming from a README.md file on the
+ Hugging Face Hub.
+
+ Returns:
+ model_name (`str`):
+ The name of the model as found in the model index. This is used as the
+ identifier for the model on leaderboards like PapersWithCode.
+ eval_results (`List[EvalResult]`):
+ A list of `huggingface_hub.EvalResult` objects containing the metrics
+ reported in the provided model_index.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard_data import model_index_to_eval_results
+ >>> # Define a minimal model index
+ >>> model_index = [
+ ... {
+ ... "name": "my-cool-model",
+ ... "results": [
+ ... {
+ ... "task": {
+ ... "type": "image-classification"
+ ... },
+ ... "dataset": {
+ ... "type": "beans",
+ ... "name": "Beans"
+ ... },
+ ... "metrics": [
+ ... {
+ ... "type": "accuracy",
+ ... "value": 0.9
+ ... }
+ ... ]
+ ... }
+ ... ]
+ ... }
+ ... ]
+ >>> model_name, eval_results = model_index_to_eval_results(model_index)
+ >>> model_name
+ 'my-cool-model'
+ >>> eval_results[0].task_type
+ 'image-classification'
+ >>> eval_results[0].metric_type
+ 'accuracy'
+
+ ```
+ """
+
+ eval_results = []
+ for elem in model_index:
+ name = elem["name"]
+ results = elem["results"]
+ for result in results:
+ task_type = result["task"]["type"]
+ task_name = result["task"].get("name")
+ dataset_type = result["dataset"]["type"]
+ dataset_name = result["dataset"]["name"]
+ dataset_config = result["dataset"].get("config")
+ dataset_split = result["dataset"].get("split")
+ dataset_revision = result["dataset"].get("revision")
+ dataset_args = result["dataset"].get("args")
+
+ for metric in result["metrics"]:
+ metric_type = metric["type"]
+ metric_value = metric["value"]
+ metric_name = metric.get("name")
+ metric_args = metric.get("args")
+ metric_config = metric.get("config")
+ verified = metric.get("verified")
+
+ eval_result = EvalResult(
+ task_type=task_type, # Required
+ dataset_type=dataset_type, # Required
+ dataset_name=dataset_name, # Required
+ metric_type=metric_type, # Required
+ metric_value=metric_value, # Required
+ task_name=task_name,
+ dataset_config=dataset_config,
+ dataset_split=dataset_split,
+ dataset_revision=dataset_revision,
+ dataset_args=dataset_args,
+ metric_name=metric_name,
+ metric_args=metric_args,
+ metric_config=metric_config,
+ verified=verified,
+ )
+ eval_results.append(eval_result)
+ return name, eval_results
+
+
+def _remove_none(obj):
+ """
+ Recursively remove `None` values from a dict. Borrowed from: https://stackoverflow.com/a/20558778
+ """
+ if isinstance(obj, (list, tuple, set)):
+ return type(obj)(_remove_none(x) for x in obj if x is not None)
+ elif isinstance(obj, dict):
+ return type(obj)(
+ (_remove_none(k), _remove_none(v))
+ for k, v in obj.items()
+ if k is not None and v is not None
+ )
+ else:
+ return obj
+
+
+def eval_results_to_model_index(
+ model_name: str, eval_results: List[EvalResult]
+) -> List[Dict[str, Any]]:
+ """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a
+ valid model-index that will be compatible with the format expected by the
+ Hugging Face Hub.
+
+ Args:
+ model_name (`str`):
+ Name of the model (ex. "my-cool-model"). This is used as the identifier
+ for the model on leaderboards like PapersWithCode.
+ eval_results (`List[EvalResult]`):
+ List of `huggingface_hub.EvalResult` objects containing the metrics to be
+ reported in the model-index.
+
+ Returns:
+ model_index (`List[Dict[str, Any]]`): The eval_results converted to a model-index.
+
+ Example:
+ ```python
+ >>> from huggingface_hub.repocard_data import eval_results_to_model_index, EvalResult
+ >>> # Define minimal eval_results
+ >>> eval_results = [
+ ... EvalResult(
+ ... task_type="image-classification", # Required
+ ... dataset_type="beans", # Required
+ ... dataset_name="Beans", # Required
+ ... metric_type="accuracy", # Required
+ ... metric_value=0.9, # Required
+ ... )
+ ... ]
+ >>> eval_results_to_model_index("my-cool-model", eval_results)
+ [{'name': 'my-cool-model', 'results': [{'task': {'type': 'image-classification'}, 'dataset': {'name': 'Beans', 'type': 'beans'}, 'metrics': [{'type': 'accuracy', 'value': 0.9}]}]}]
+
+ ```
+ """
+
+ # Metrics are reported on a unique task-and-dataset basis.
+ # Here, we make a map of those pairs and the associated EvalResults.
+ task_and_ds_types_map = defaultdict(list)
+ for eval_result in eval_results:
+ task_and_ds_pair = (eval_result.task_type, eval_result.dataset_type)
+ task_and_ds_types_map[task_and_ds_pair].append(eval_result)
+
+ # Use the map from above to generate the model index data.
+ model_index_data = []
+ for (task_type, dataset_type), results in task_and_ds_types_map.items():
+ data = {
+ "task": {
+ "type": task_type,
+ "name": results[0].task_name,
+ },
+ "dataset": {
+ "name": results[0].dataset_name,
+ "type": dataset_type,
+ "config": results[0].dataset_config,
+ "split": results[0].dataset_split,
+ "revision": results[0].dataset_revision,
+ "args": results[0].dataset_args,
+ },
+ "metrics": [
+ {
+ "type": result.metric_type,
+ "value": result.metric_value,
+ "name": result.metric_name,
+ "config": result.metric_config,
+ "args": result.metric_args,
+ "verified": result.verified,
+ }
+ for result in results
+ ],
+ }
+ model_index_data.append(data)
+
+    # TODO - Check if there are cases where this list is longer than one?
+    # Finally, the model index itself is a list of dicts.
+ model_index = [
+ {
+ "name": model_name,
+ "results": model_index_data,
+ }
+ ]
+ return _remove_none(model_index)
diff --git a/src/huggingface_hub/repocard_types.py b/src/huggingface_hub/repocard_types.py
deleted file mode 100644
index a87a37e34a..0000000000
--- a/src/huggingface_hub/repocard_types.py
+++ /dev/null
@@ -1,105 +0,0 @@
-from dataclasses import dataclass, field
-from typing import Any, List, Optional
-
-from typing_extensions import TypeAlias
-
-
-ModelIndexSet: TypeAlias = "List[ModelIndex]"
-
-
-@dataclass
-class ModelIndex:
- name: str
- results: "List[SingleResult]"
-
-
-@dataclass
-class SingleMetric:
- type: str
- """
- Example: wer. Use metric id from https://hf.co/metrics
- """
-
- value: Any
- """
- Example: 20.0 or "20.0 ± 1.2"
- """
-
- name: Optional[str] = None
- """
- Example: Test WER
- """
-
- config: Optional[str] = None
- """
- The name of the metric configuration used in `load_metric()`. Example: bleurt-large-512 in `load_metric("bleurt", "bleurt-large-512")`. \
- See the `datasets` docs for more info: https://huggingface.co/docs/datasets/v2.1.0/en/loading#load-configurations
- """
-
- args: Any = field(default=None)
- """
- The arguments passed during `Metric.compute()`. Example for `bleu`: max_order: 4
- """
-
- verified: Optional[bool] = None
- """
- If true, indicates that evaluation was generated by Hugging Face (vs. self-reported).
- """
-
-
-@dataclass
-class SingleResultTask:
- type: str
- """
- Example: automatic-speech-recognition Use task id from
- hhttps://github.com/huggingface/hub-docs/blob/main/js/src/lib/interfaces/Types.ts
- """
-
- name: Optional[str] = None
- """
- Example: Speech Recognition
- """
-
-
-@dataclass
-class SingleResultDataset:
- name: str
- """
- Example: Common Voice (French). A pretty name for the dataset.
- """
-
- type: str
- """
- Example: common_voice. Use dataset id from https://hf.co/datasets
- """
-
- config: Optional[str] = None
- """
- Example: fr. The name of the dataset configuration used in `load_dataset()`
- """
-
- split: Optional[str] = None
- """
- Example: test.
- """
-
- revision: Optional[str] = None
- """
- Example: 5503434ddd753f426f4b38109466949a1217c2bb.
- """
-
- args: Any = None
- """
- Optional. Additional arguments to `load_dataset()`. Example for wikipedia: { language: "en" }
- """
-
-
-@dataclass
-class SingleResult:
- task: "SingleResultTask"
- dataset: "Optional[SingleResultDataset]"
- """
- This will switch to required at some point. in any case, we need them to
- link to PWC
- """
- metrics: "List[SingleMetric]"
diff --git a/src/huggingface_hub/templates/datasetcard_template.md b/src/huggingface_hub/templates/datasetcard_template.md
new file mode 100644
index 0000000000..94285e6142
--- /dev/null
+++ b/src/huggingface_hub/templates/datasetcard_template.md
@@ -0,0 +1,126 @@
+---
+{{ card_data }}
+---
+
+# Dataset Card for {{ pretty_name | default("Dataset Name", true) }}
+
+## Table of Contents
+- [Table of Contents](#table-of-contents)
+- [Dataset Description](#dataset-description)
+ - [Dataset Summary](#dataset-summary)
+ - [Supported Tasks and Leaderboards](#supported-tasks-and-leaderboards)
+ - [Languages](#languages)
+- [Dataset Structure](#dataset-structure)
+ - [Data Instances](#data-instances)
+ - [Data Fields](#data-fields)
+ - [Data Splits](#data-splits)
+- [Dataset Creation](#dataset-creation)
+ - [Curation Rationale](#curation-rationale)
+ - [Source Data](#source-data)
+ - [Annotations](#annotations)
+ - [Personal and Sensitive Information](#personal-and-sensitive-information)
+- [Considerations for Using the Data](#considerations-for-using-the-data)
+ - [Social Impact of Dataset](#social-impact-of-dataset)
+ - [Discussion of Biases](#discussion-of-biases)
+ - [Other Known Limitations](#other-known-limitations)
+- [Additional Information](#additional-information)
+ - [Dataset Curators](#dataset-curators)
+ - [Licensing Information](#licensing-information)
+ - [Citation Information](#citation-information)
+ - [Contributions](#contributions)
+
+## Dataset Description
+
+- **Homepage:**
+- **Repository:**
+- **Paper:**
+- **Leaderboard:**
+- **Point of Contact:**
+
+### Dataset Summary
+
+[More Information Needed]
+
+### Supported Tasks and Leaderboards
+
+[More Information Needed]
+
+### Languages
+
+[More Information Needed]
+
+## Dataset Structure
+
+### Data Instances
+
+[More Information Needed]
+
+### Data Fields
+
+[More Information Needed]
+
+### Data Splits
+
+[More Information Needed]
+
+## Dataset Creation
+
+### Curation Rationale
+
+[More Information Needed]
+
+### Source Data
+
+#### Initial Data Collection and Normalization
+
+[More Information Needed]
+
+#### Who are the source language producers?
+
+[More Information Needed]
+
+### Annotations
+
+#### Annotation process
+
+[More Information Needed]
+
+#### Who are the annotators?
+
+[More Information Needed]
+
+### Personal and Sensitive Information
+
+[More Information Needed]
+
+## Considerations for Using the Data
+
+### Social Impact of Dataset
+
+[More Information Needed]
+
+### Discussion of Biases
+
+[More Information Needed]
+
+### Other Known Limitations
+
+[More Information Needed]
+
+## Additional Information
+
+### Dataset Curators
+
+[More Information Needed]
+
+### Licensing Information
+
+[More Information Needed]
+
+### Citation Information
+
+[More Information Needed]
+
+### Contributions
+
+Thanks to [@github-username](https://github.com/<github-username>) for adding this dataset.
\ No newline at end of file
diff --git a/src/huggingface_hub/templates/modelcard_template.md b/src/huggingface_hub/templates/modelcard_template.md
new file mode 100644
index 0000000000..8f8a770a0a
--- /dev/null
+++ b/src/huggingface_hub/templates/modelcard_template.md
@@ -0,0 +1,201 @@
+---
+{{card_data}}
+---
+
+# Model Card for {{ model_id | default("Model ID", true) }}
+
+
+
+# Table of Contents
+
+1. [Model Details](#model-details)
+2. [Uses](#uses)
+3. [Bias, Risks, and Limitations](#bias-risks-and-limitations)
+4. [Training Details](#training-details)
+5. [Evaluation](#evaluation)
+6. [Model Examination](#model-examination)
+7. [Environmental Impact](#environmental-impact)
+8. [Technical Specifications](#technical-specifications-optional)
+9. [Citation](#citation)
+10. [Glossary](#glossary-optional)
+11. [More Information](#more-information-optional)
+12. [Model Card Authors](#model-card-authors-optional)
+13. [Model Card Contact](#model-card-contact)
+14. [How To Get Started With the Model](#how-to-get-started-with-the-model)
+
+
+# Model Details
+
+## Model Description
+
+
+
+{{ model_description | default("", true) }}
+
+- **Developed by:** {{ developers | default("[More Information Needed]", true)}}
+- **Shared by [Optional]:** {{ shared_by | default("[More Information Needed]", true)}}
+- **Model type:** {{ model_type | default("[More Information Needed]", true)}}
+- **Language(s) (NLP):** {{ language | default("[More Information Needed]", true)}}
+- **License:** {{ license | default("[More Information Needed]", true)}}
+- **Related Models:** {{ related_models | default("[More Information Needed]", true)}}
+ - **Parent Model:** {{ parent_model | default("[More Information Needed]", true)}}
+- **Resources for more information:** {{ more_resources | default("[More Information Needed]", true)}}
+
+# Uses
+
+
+
+## Direct Use
+
+
+
+{{ direct_use | default("[More Information Needed]", true)}}
+
+## Downstream Use [Optional]
+
+
+
+{{ downstream_use | default("[More Information Needed]", true)}}
+
+## Out-of-Scope Use
+
+
+
+{{ out_of_scope_use | default("[More Information Needed]", true)}}
+
+# Bias, Risks, and Limitations
+
+
+
+{{ bias_risks_limitations | default("[More Information Needed]", true)}}
+
+## Recommendations
+
+
+
+{{ bias_recommendations | default("Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.", true)}}
+
+# Training Details
+
+## Training Data
+
+
+
+{{ training_data | default("[More Information Needed]", true)}}
+
+## Training Procedure
+
+
+
+### Preprocessing
+
+{{ preprocessing | default("[More Information Needed]", true)}}
+
+### Speeds, Sizes, Times
+
+
+
+{{ speeds_sizes_times | default("[More Information Needed]", true)}}
+
+# Evaluation
+
+
+
+## Testing Data, Factors & Metrics
+
+### Testing Data
+
+
+
+{{ testing_data | default("[More Information Needed]", true)}}
+
+### Factors
+
+
+
+{{ testing_factors | default("[More Information Needed]", true)}}
+
+### Metrics
+
+
+
+{{ testing_metrics | default("[More Information Needed]", true)}}
+
+## Results
+
+{{ results | default("[More Information Needed]", true)}}
+
+# Model Examination
+
+{{ model_examination | default("[More Information Needed]", true)}}
+
+# Environmental Impact
+
+
+
+Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+- **Hardware Type:** {{ hardware | default("[More Information Needed]", true)}}
+- **Hours used:** {{ hours_used | default("[More Information Needed]", true)}}
+- **Cloud Provider:** {{ cloud_provider | default("[More Information Needed]", true)}}
+- **Compute Region:** {{ cloud_region | default("[More Information Needed]", true)}}
+- **Carbon Emitted:** {{ co2_emitted | default("[More Information Needed]", true)}}
+
+# Technical Specifications [optional]
+
+## Model Architecture and Objective
+
+{{ model_specs | default("[More Information Needed]", true)}}
+
+## Compute Infrastructure
+
+{{ compute_infrastructure | default("[More Information Needed]", true)}}
+
+### Hardware
+
+{{ hardware | default("[More Information Needed]", true)}}
+
+### Software
+
+{{ software | default("[More Information Needed]", true)}}
+
+# Citation
+
+
+
+**BibTeX:**
+
+{{ citation_bibtex | default("[More Information Needed]", true)}}
+
+**APA:**
+
+{{ citation_apa | default("[More Information Needed]", true)}}
+
+# Glossary [optional]
+
+
+
+{{ glossary | default("[More Information Needed]", true)}}
+
+# More Information [optional]
+
+{{ more_information | default("[More Information Needed]", true)}}
+
+# Model Card Authors [optional]
+
+{{ model_card_authors | default("[More Information Needed]", true)}}
+
+# Model Card Contact
+
+{{ model_card_contact | default("[More Information Needed]", true)}}
+
+# How to Get Started with the Model
+
+Use the code below to get started with the model.
+
+<details>
+<summary> Click to expand </summary>
+
+{{ get_started_code | default("[More Information Needed]", true)}}
+
+</details>
diff --git a/tests/fixtures/cards/sample_datasetcard_simple.md b/tests/fixtures/cards/sample_datasetcard_simple.md
new file mode 100644
index 0000000000..1c402b1b25
--- /dev/null
+++ b/tests/fixtures/cards/sample_datasetcard_simple.md
@@ -0,0 +1,24 @@
+---
+language:
+- en
+license:
+- bsd-3-clause
+annotations_creators:
+- crowdsourced
+- expert-generated
+language_creators:
+- found
+multilinguality:
+- monolingual
+size_categories:
+- n<1K
+task_categories:
+- image-segmentation
+task_ids:
+- semantic-segmentation
+pretty_name: Sample Segmentation
+---
+
+# Dataset Card for Sample Segmentation
+
+This is a sample dataset card for a semantic segmentation dataset.
\ No newline at end of file
diff --git a/tests/fixtures/cards/sample_datasetcard_template.md b/tests/fixtures/cards/sample_datasetcard_template.md
new file mode 100644
index 0000000000..6c6a7a177a
--- /dev/null
+++ b/tests/fixtures/cards/sample_datasetcard_template.md
@@ -0,0 +1,7 @@
+---
+{{card_data}}
+---
+
+# {{ pretty_name | default("Dataset Name", true)}}
+
+{{ some_data }}
diff --git a/tests/fixtures/cards/sample_invalid_card_data.md b/tests/fixtures/cards/sample_invalid_card_data.md
new file mode 100644
index 0000000000..844fd89b52
--- /dev/null
+++ b/tests/fixtures/cards/sample_invalid_card_data.md
@@ -0,0 +1,7 @@
+---
+[]
+---
+
+# invalid-card-data
+
+Loading this card should fail because the card data between the `---` markers is a list instead of a dict.
diff --git a/tests/fixtures/cards/sample_invalid_model_index.md b/tests/fixtures/cards/sample_invalid_model_index.md
new file mode 100644
index 0000000000..affb608299
--- /dev/null
+++ b/tests/fixtures/cards/sample_invalid_model_index.md
@@ -0,0 +1,24 @@
+---
+language: en
+license: mit
+library_name: timm
+tags:
+- pytorch
+- image-classification
+datasets:
+- beans
+metrics:
+- acc
+model-index:
+- name: my-cool-model
+ results:
+ - task:
+ type: image-classification
+ metrics:
+ - type: acc
+ value: 0.9
+---
+
+# Invalid Model Index
+
+In this example, the model index does not define a dataset field, so we still initialize CardData but leave model-index/eval_results out of it.
diff --git a/tests/fixtures/cards/sample_no_metadata.md b/tests/fixtures/cards/sample_no_metadata.md
new file mode 100644
index 0000000000..af6a7c73e4
--- /dev/null
+++ b/tests/fixtures/cards/sample_no_metadata.md
@@ -0,0 +1,3 @@
+# MyCoolModel
+
+In this example, we don't have any metadata at the top of the file. In cases like these, `CardData` should be instantiated as empty.
diff --git a/tests/fixtures/cards/sample_simple.md b/tests/fixtures/cards/sample_simple.md
new file mode 100644
index 0000000000..292aec19ab
--- /dev/null
+++ b/tests/fixtures/cards/sample_simple.md
@@ -0,0 +1,19 @@
+---
+language:
+- en
+license: mit
+library_name: pytorch-lightning
+tags:
+- pytorch
+- image-classification
+datasets:
+- beans
+metrics:
+- acc
+---
+
+# my-cool-model
+
+## Model description
+
+You can embed local or remote images using `![](...)`
diff --git a/tests/fixtures/cards/sample_simple_model_index.md b/tests/fixtures/cards/sample_simple_model_index.md
new file mode 100644
index 0000000000..15d26c7069
--- /dev/null
+++ b/tests/fixtures/cards/sample_simple_model_index.md
@@ -0,0 +1,29 @@
+---
+language: en
+license: mit
+library_name: timm
+tags:
+- pytorch
+- image-classification
+datasets:
+- beans
+metrics:
+- acc
+model-index:
+- name: my-cool-model
+ results:
+ - task:
+ type: image-classification
+ dataset:
+ type: beans
+ name: Beans
+ metrics:
+ - type: acc
+ value: 0.9
+---
+
+# my-cool-model
+
+## Model description
+
+You can embed local or remote images using `![](...)`
diff --git a/tests/fixtures/cards/sample_template.md b/tests/fixtures/cards/sample_template.md
new file mode 100644
index 0000000000..99a680f8d8
--- /dev/null
+++ b/tests/fixtures/cards/sample_template.md
@@ -0,0 +1,7 @@
+---
+{{card_data}}
+---
+
+# {{ model_name | default("MyModelName", true)}}
+
+{{ some_data }}
diff --git a/tests/fixtures/cards/sample_windows_line_breaks.md b/tests/fixtures/cards/sample_windows_line_breaks.md
new file mode 100644
index 0000000000..d23f29e972
--- /dev/null
+++ b/tests/fixtures/cards/sample_windows_line_breaks.md
@@ -0,0 +1,11 @@
+---
+license: mit
+language: eo
+thumbnail: https://huggingface.co/blog/assets/01_how-to-train/EsperBERTo-thumbnail-v2.png
+widget:
+- text: "Jen la komenco de bela <mask>."
+- text: "Uno du <mask>"
+- text: "Jen finiĝas bela <mask>."
+---
+
+# Hello old Windows line breaks
diff --git a/tests/test_repocard.py b/tests/test_repocard.py
index 0c63020ef0..0658723368 100644
--- a/tests/test_repocard.py
+++ b/tests/test_repocard.py
@@ -21,20 +21,33 @@
import pytest
+import requests
import yaml
from huggingface_hub import (
+ DatasetCard,
+ DatasetCardData,
+ EvalResult,
+ ModelCard,
+ ModelCardData,
metadata_eval_result,
metadata_load,
metadata_save,
metadata_update,
)
from huggingface_hub.constants import REPOCARD_NAME
-from huggingface_hub.file_download import hf_hub_download
+from huggingface_hub.file_download import hf_hub_download, is_jinja_available
from huggingface_hub.hf_api import HfApi
+from huggingface_hub.repocard import RepoCard
+from huggingface_hub.repocard_data import CardData
from huggingface_hub.repository import Repository
from huggingface_hub.utils import logging
-from .testing_constants import ENDPOINT_STAGING, TOKEN, USER
+from .testing_constants import (
+ ENDPOINT_STAGING,
+ ENDPOINT_STAGING_BASIC_AUTH,
+ TOKEN,
+ USER,
+)
from .testing_utils import (
expect_deprecation,
repo_name,
@@ -43,6 +56,8 @@
)
+SAMPLE_CARDS_DIR = Path(__file__).parent / "fixtures/cards"
+
ROUND_TRIP_MODELCARD_CASE = """
---
language: no
@@ -124,7 +139,19 @@
repo_name = partial(repo_name, prefix="dummy-hf-hub")
-class RepocardTest(unittest.TestCase):
+def require_jinja(test_case):
+ """
+ Decorator marking a test that requires Jinja2.
+
+ These tests are skipped when Jinja2 is not installed.
+ """
+ if not is_jinja_available():
+ return unittest.skip("test requires Jinja2.")(test_case)
+ else:
+ return test_case
+
+
+class RepocardMetadataTest(unittest.TestCase):
def setUp(self):
os.makedirs(REPOCARD_DIR, exist_ok=True)
@@ -191,7 +218,7 @@ def test_metadata_eval_result(self):
self.assertEqual(content, DUMMY_MODELCARD_EVAL_RESULT.splitlines())
-class RepocardUpdateTest(unittest.TestCase):
+class RepocardMetadataUpdateTest(unittest.TestCase):
_api = HfApi(endpoint=ENDPOINT_STAGING)
@classmethod
@@ -228,7 +255,7 @@ def tearDown(self) -> None:
shutil.rmtree(self.repo_path)
def test_update_dataset_name(self):
- new_datasets_data = {"datasets": "['test/test_dataset']"}
+ new_datasets_data = {"datasets": ["test/test_dataset"]}
metadata_update(
f"{USER}/{self.REPO_NAME}", new_datasets_data, token=self._token
)
@@ -366,3 +393,350 @@ def test_update_new_result_new_dataset(self):
self.repo.git_pull()
updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md")
self.assertDictEqual(updated_metadata, expected_metadata)
+
+
+class TestCaseWithCapLog(unittest.TestCase):
+ _api = HfApi(endpoint=ENDPOINT_STAGING)
+
+ @pytest.fixture(autouse=True)
+ def inject_fixtures(self, caplog):
+ """Assign pytest caplog as attribute so we can use captured log messages in tests below."""
+ self.caplog = caplog
+
+
+class RepoCardTest(TestCaseWithCapLog):
+ def test_load_repocard_from_file(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_simple.md"
+ card = RepoCard.load(sample_path)
+ self.assertEqual(
+ card.data.to_dict(),
+ {
+ "language": ["en"],
+ "license": "mit",
+ "library_name": "pytorch-lightning",
+ "tags": ["pytorch", "image-classification"],
+ "datasets": ["beans"],
+ "metrics": ["acc"],
+ },
+ )
+ self.assertTrue(
+ card.text.strip().startswith("# my-cool-model"),
+ "Card text not loaded properly",
+ )
+
+ def test_change_repocard_data(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_simple.md"
+ card = RepoCard.load(sample_path)
+ card.data.language = ["fr"]
+
+ with tempfile.TemporaryDirectory() as tempdir:
+ updated_card_path = Path(tempdir) / "updated.md"
+ card.save(updated_card_path)
+
+ updated_card = RepoCard.load(updated_card_path)
+ self.assertEqual(
+ updated_card.data.language, ["fr"], "Card data not updated properly"
+ )
+
+ @require_jinja
+ def test_repo_card_from_default_template(self):
+ card = RepoCard.from_template(
+ card_data=CardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags=["image-classification", "resnet"],
+ datasets="imagenet",
+ metrics=["acc", "f1"],
+ ),
+ model_id=None,
+ )
+ self.assertIsInstance(card, RepoCard)
+ self.assertTrue(
+ card.text.strip().startswith("# Model Card for Model ID"),
+ "Default model name not set correctly",
+ )
+
+ @require_jinja
+ def test_repo_card_from_default_template_with_model_id(self):
+ card = RepoCard.from_template(
+ card_data=CardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags=["image-classification", "resnet"],
+ datasets="imagenet",
+ metrics=["acc", "f1"],
+ ),
+ model_id="my-cool-model",
+ )
+ self.assertTrue(
+ card.text.strip().startswith("# Model Card for my-cool-model"),
+ "model_id not properly set in card template",
+ )
+
+ @require_jinja
+ def test_repo_card_from_custom_template(self):
+ template_path = SAMPLE_CARDS_DIR / "sample_template.md"
+ card = RepoCard.from_template(
+ card_data=CardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags="text-classification",
+ datasets="glue",
+ metrics="acc",
+ ),
+ template_path=template_path,
+ some_data="asdf",
+ )
+ self.assertTrue(
+ card.text.endswith("asdf"),
+ "Custom template didn't set jinja variable correctly",
+ )
+
+ def test_repo_card_data_must_be_dict(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_invalid_card_data.md"
+ with pytest.raises(
+ ValueError, match="repo card metadata block should be a dict"
+ ):
+ RepoCard(sample_path.read_text())
+
+ def test_repo_card_without_metadata(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_no_metadata.md"
+
+ with self.caplog.at_level(logging.WARNING):
+ card = RepoCard(sample_path.read_text())
+ self.assertIn(
+ "Repo card metadata block was not found. Setting CardData to empty.",
+ self.caplog.text,
+ )
+ self.assertEqual(card.data, CardData())
+
+ def test_validate_repocard(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_simple.md"
+ card = RepoCard.load(sample_path)
+ card.validate()
+
+ card.data.license = "asdf"
+ with pytest.raises(ValueError, match='- Error: "license" must be one of'):
+ card.validate()
+
+ def test_push_to_hub(self):
+ repo_id = f"{USER}/{repo_name('push-card')}"
+ self._api.create_repo(repo_id, token=TOKEN)
+
+ card_data = CardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags=["text-classification"],
+ datasets="glue",
+ metrics="acc",
+ )
+ # Mock what RepoCard.from_template does so we can test w/o Jinja2
+ content = f"{card_data.to_yaml()}\n\n# MyModel\n\nHello, world!"
+ card = RepoCard(content)
+
+ url = f"{ENDPOINT_STAGING_BASIC_AUTH}/{repo_id}/resolve/main/README.md"
+
+ # Check this file doesn't exist (sanity check)
+ with pytest.raises(requests.exceptions.HTTPError):
+ r = requests.get(url)
+ r.raise_for_status()
+
+ # Push the card up to README.md in the repo
+ card.push_to_hub(repo_id, token=TOKEN)
+
+ # No error should occur now, as README.md should exist
+ r = requests.get(url)
+ r.raise_for_status()
+
+ self._api.delete_repo(repo_id=repo_id, token=TOKEN)
+
+ def test_push_and_create_pr(self):
+ repo_id = f"{USER}/{repo_name('pr-card')}"
+ self._api.create_repo(repo_id, token=TOKEN)
+ card_data = CardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags=["text-classification"],
+ datasets="glue",
+ metrics="acc",
+ )
+ # Mock what RepoCard.from_template does so we can test w/o Jinja2
+ content = f"{card_data.to_yaml()}\n\n# MyModel\n\nHello, world!"
+ card = RepoCard(content)
+
+ url = f"{ENDPOINT_STAGING_BASIC_AUTH}/api/models/{repo_id}/discussions"
+ r = requests.get(url)
+ data = r.json()
+ self.assertEqual(data["count"], 0)
+ card.push_to_hub(repo_id, token=TOKEN, create_pr=True)
+ r = requests.get(url)
+ data = r.json()
+ self.assertEqual(data["count"], 1)
+
+ self._api.delete_repo(repo_id=repo_id, token=TOKEN)
+
+ def test_preserve_windows_linebreaks(self):
+ card_path = SAMPLE_CARDS_DIR / "sample_windows_line_breaks.md"
+ card = RepoCard.load(card_path)
+ self.assertIn("\r\n", str(card))
+
+
+class ModelCardTest(TestCaseWithCapLog):
+ def test_model_card_with_invalid_model_index(self):
+ """
+        Test that when loading a card with an invalid model-index, no eval_results are added and a warning is logged.
+ """
+ sample_path = SAMPLE_CARDS_DIR / "sample_invalid_model_index.md"
+ with self.caplog.at_level(logging.WARNING):
+ card = ModelCard.load(sample_path)
+ self.assertIn(
+ "Invalid model-index. Not loading eval results into CardData.",
+ self.caplog.text,
+ )
+ self.assertIsNone(card.data.eval_results)
+
+ def test_load_model_card_from_file(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_simple.md"
+ card = ModelCard.load(sample_path)
+ self.assertIsInstance(card, ModelCard)
+ self.assertEqual(
+ card.data.to_dict(),
+ {
+ "language": ["en"],
+ "license": "mit",
+ "library_name": "pytorch-lightning",
+ "tags": ["pytorch", "image-classification"],
+ "datasets": ["beans"],
+ "metrics": ["acc"],
+ },
+ )
+ self.assertTrue(
+ card.text.strip().startswith("# my-cool-model"),
+ "Card text not loaded properly",
+ )
+
+ @require_jinja
+ def test_model_card_from_custom_template(self):
+ template_path = SAMPLE_CARDS_DIR / "sample_template.md"
+ card = ModelCard.from_template(
+ card_data=ModelCardData(
+ language="en",
+ license="mit",
+ library_name="pytorch",
+ tags="text-classification",
+ datasets="glue",
+ metrics="acc",
+ ),
+ template_path=template_path,
+ some_data="asdf",
+ )
+ self.assertIsInstance(card, ModelCard)
+ self.assertTrue(
+ card.text.endswith("asdf"),
+ "Custom template didn't set jinja variable correctly",
+ )
+
+ @require_jinja
+ def test_model_card_from_template_eval_results(self):
+ template_path = SAMPLE_CARDS_DIR / "sample_template.md"
+ card = ModelCard.from_template(
+ card_data=ModelCardData(
+ eval_results=[
+ EvalResult(
+ task_type="text-classification",
+ task_name="Text Classification",
+ dataset_type="julien-c/reactiongif",
+ dataset_name="ReactionGIF",
+ dataset_config="default",
+ dataset_split="test",
+ metric_type="accuracy",
+ metric_value=0.2662102282047272,
+ metric_name="Accuracy",
+ metric_config="default",
+ verified=False,
+ ),
+ ],
+ model_name="RoBERTa fine-tuned on ReactionGIF",
+ ),
+ template_path=template_path,
+ some_data="asdf",
+ )
+ self.assertIsInstance(card, ModelCard)
+ self.assertTrue(card.text.endswith("asdf"))
+ self.assertTrue(card.data.to_dict().get("eval_results") is None)
+ self.assertEqual(
+ str(card)[: len(DUMMY_MODELCARD_EVAL_RESULT)], DUMMY_MODELCARD_EVAL_RESULT
+ )
+
+
+class DatasetCardTest(TestCaseWithCapLog):
+ def test_load_datasetcard_from_file(self):
+ sample_path = SAMPLE_CARDS_DIR / "sample_datasetcard_simple.md"
+ card = DatasetCard.load(sample_path)
+ self.assertEqual(
+ card.data.to_dict(),
+ {
+ "annotations_creators": ["crowdsourced", "expert-generated"],
+ "language_creators": ["found"],
+ "language": ["en"],
+ "license": ["bsd-3-clause"],
+ "multilinguality": ["monolingual"],
+ "size_categories": ["n<1K"],
+ "task_categories": ["image-segmentation"],
+ "task_ids": ["semantic-segmentation"],
+ "pretty_name": "Sample Segmentation",
+ },
+ )
+ self.assertIsInstance(card, DatasetCard)
+ self.assertIsInstance(card.data, DatasetCardData)
+ self.assertTrue(card.text.strip().startswith("# Dataset Card for"))
+
+ @require_jinja
+ def test_dataset_card_from_default_template(self):
+ card_data = DatasetCardData(
+ language="en",
+ license="mit",
+ )
+
+ # Here we check default title when pretty_name not provided.
+ card = DatasetCard.from_template(card_data)
+ self.assertTrue(card.text.strip().startswith("# Dataset Card for Dataset Name"))
+
+ card_data = DatasetCardData(
+ language="en",
+ license="mit",
+ pretty_name="My Cool Dataset",
+ )
+
+ # Here we pass the card data as kwargs as well so template picks up pretty_name.
+ card = DatasetCard.from_template(card_data, **card_data.to_dict())
+ self.assertTrue(
+ card.text.strip().startswith("# Dataset Card for My Cool Dataset")
+ )
+
+ self.assertIsInstance(card, DatasetCard)
+
+ @require_jinja
+ def test_dataset_card_from_custom_template(self):
+ card = DatasetCard.from_template(
+ card_data=DatasetCardData(
+ language="en",
+ license="mit",
+ pretty_name="My Cool Dataset",
+ ),
+ template_path=SAMPLE_CARDS_DIR / "sample_datasetcard_template.md",
+ pretty_name="My Cool Dataset",
+ some_data="asdf",
+ )
+ self.assertIsInstance(card, DatasetCard)
+
+ # Title this time is just # {{ pretty_name }}
+ self.assertTrue(card.text.strip().startswith("# My Cool Dataset"))
+
+ # some_data is at the bottom of the template, so should end with whatever we passed to it
+ self.assertTrue(card.text.strip().endswith("asdf"))
diff --git a/tests/test_repocard_data.py b/tests/test_repocard_data.py
new file mode 100644
index 0000000000..514345b50c
--- /dev/null
+++ b/tests/test_repocard_data.py
@@ -0,0 +1,197 @@
+import unittest
+
+import pytest
+
+import yaml
+from huggingface_hub.repocard_data import (
+ DatasetCardData,
+ EvalResult,
+ ModelCardData,
+ eval_results_to_model_index,
+ model_index_to_eval_results,
+)
+
+
+DUMMY_METADATA_WITH_MODEL_INDEX = """
+language: en
+license: mit
+library_name: timm
+tags:
+- pytorch
+- image-classification
+datasets:
+- beans
+metrics:
+- acc
+model-index:
+- name: my-cool-model
+ results:
+ - task:
+ type: image-classification
+ dataset:
+ type: beans
+ name: Beans
+ metrics:
+ - type: acc
+ value: 0.9
+"""
+
+
+class ModelCardDataTest(unittest.TestCase):
+ def test_eval_results_to_model_index(self):
+ expected_results = yaml.safe_load(DUMMY_METADATA_WITH_MODEL_INDEX)
+
+ eval_results = [
+ EvalResult(
+ task_type="image-classification",
+ dataset_type="beans",
+ dataset_name="Beans",
+ metric_type="acc",
+ metric_value=0.9,
+ ),
+ ]
+
+ model_index = eval_results_to_model_index("my-cool-model", eval_results)
+
+ self.assertEqual(model_index, expected_results["model-index"])
+
+ def test_model_index_to_eval_results(self):
+ model_index = [
+ {
+ "name": "my-cool-model",
+ "results": [
+ {
+ "task": {
+ "type": "image-classification",
+ },
+ "dataset": {
+ "type": "cats_vs_dogs",
+ "name": "Cats vs. Dogs",
+ },
+ "metrics": [
+ {
+ "type": "acc",
+ "value": 0.85,
+ },
+ {
+ "type": "f1",
+ "value": 0.9,
+ },
+ ],
+ },
+ {
+ "task": {
+ "type": "image-classification",
+ },
+ "dataset": {
+ "type": "beans",
+ "name": "Beans",
+ },
+ "metrics": [
+ {
+ "type": "acc",
+ "value": 0.9,
+ }
+ ],
+ },
+ ],
+ }
+ ]
+ model_name, eval_results = model_index_to_eval_results(model_index)
+
+ self.assertEqual(len(eval_results), 3)
+ self.assertEqual(model_name, "my-cool-model")
+ self.assertEqual(eval_results[0].dataset_type, "cats_vs_dogs")
+ self.assertEqual(eval_results[1].metric_type, "f1")
+ self.assertEqual(eval_results[1].metric_value, 0.9)
+ self.assertEqual(eval_results[2].task_type, "image-classification")
+ self.assertEqual(eval_results[2].dataset_type, "beans")
+
+ def test_card_data_requires_model_name_for_eval_results(self):
+ with pytest.raises(
+ ValueError, match="`eval_results` requires `model_name` to be set."
+ ):
+ ModelCardData(
+ eval_results=[
+ EvalResult(
+ task_type="image-classification",
+ dataset_type="beans",
+ dataset_name="Beans",
+ metric_type="acc",
+ metric_value=0.9,
+ ),
+ ],
+ )
+
+ data = ModelCardData(
+ model_name="my-cool-model",
+ eval_results=[
+ EvalResult(
+ task_type="image-classification",
+ dataset_type="beans",
+ dataset_name="Beans",
+ metric_type="acc",
+ metric_value=0.9,
+ ),
+ ],
+ )
+
+ model_index = eval_results_to_model_index(data.model_name, data.eval_results)
+
+ self.assertEqual(model_index[0]["name"], "my-cool-model")
+ self.assertEqual(
+ model_index[0]["results"][0]["task"]["type"], "image-classification"
+ )
+
+    def test_arbitrary_incoming_card_data(self):
+ data = ModelCardData(
+ model_name="my-cool-model",
+ eval_results=[
+ EvalResult(
+ task_type="image-classification",
+ dataset_type="beans",
+ dataset_name="Beans",
+ metric_type="acc",
+ metric_value=0.9,
+ ),
+ ],
+            some_arbitrary_kwarg="some_value",
+        )
+
+        self.assertEqual(data.some_arbitrary_kwarg, "some_value")
+
+        data_dict = data.to_dict()
+        self.assertEqual(data_dict["some_arbitrary_kwarg"], "some_value")
+
+
+class DatasetCardDataTest(unittest.TestCase):
+ def test_train_eval_index_keys_updated(self):
+ train_eval_index = [
+ {
+ "config": "plain_text",
+ "task": "text-classification",
+ "task_id": "binary_classification",
+ "splits": {"train_split": "train", "eval_split": "test"},
+ "col_mapping": {"text": "text", "label": "target"},
+ "metrics": [
+ {
+ "type": "accuracy",
+ "name": "Accuracy",
+ },
+ {"type": "f1", "name": "F1 macro", "args": {"average": "macro"}},
+ ],
+ }
+ ]
+ card_data = DatasetCardData(
+ language="en",
+ license="mit",
+ pretty_name="My Cool Dataset",
+ train_eval_index=train_eval_index,
+ )
+ # The init should have popped this out of kwargs and into train_eval_index attr
+ self.assertEqual(card_data.train_eval_index, train_eval_index)
+ # Underlying train_eval_index gets converted to train-eval-index in DatasetCardData._to_dict.
+ # So train_eval_index should be None in the dict
+ self.assertTrue(card_data.to_dict().get("train_eval_index") is None)
+ # And train-eval-index should be in the dict
+ self.assertEqual(card_data.to_dict()["train-eval-index"], train_eval_index)