diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 60394f569cd8..acabf94d9546 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -454,6 +454,7 @@ def create_metadata(self): metric_mapping = infer_metric_tags_from_eval_results(self.eval_results) metadata = {} + metadata = _insert_value(metadata, "library_name", "transformers") metadata = _insert_values_as_list(metadata, "language", self.language) metadata = _insert_value(metadata, "license", self.license) if self.finetuned_from is not None and isinstance(self.finetuned_from, str) and len(self.finetuned_from) > 0: diff --git a/tests/utils/test_model_card.py b/tests/utils/test_model_card.py index 7d0e8795e0aa..6235bb10ed7b 100644 --- a/tests/utils/test_model_card.py +++ b/tests/utils/test_model_card.py @@ -19,7 +19,7 @@ import tempfile import unittest -from transformers.modelcard import ModelCard +from transformers.modelcard import ModelCard, TrainingSummary class ModelCardTester(unittest.TestCase): @@ -82,3 +82,8 @@ def test_model_card_from_and_save_pretrained(self): model_card_second = ModelCard.from_pretrained(tmpdirname) self.assertEqual(model_card_second.to_dict(), model_card_first.to_dict()) + + def test_model_summary_modelcard_base_metadata(self): + metadata = TrainingSummary("Model name").create_metadata() + self.assertTrue("library_name" in metadata) + self.assertTrue(metadata["library_name"] == "transformers")