Add keras mixin #230
@@ -0,0 +1,108 @@
import logging
import os
from pathlib import Path

from huggingface_hub import ModelHubMixin, hf_hub_download


logger = logging.getLogger(__name__)


class KerasModelHubMixin(ModelHubMixin):

    _CONFIG_NAME = "config.json"
    _WEIGHTS_NAME = "tf_model.h5"

    def __init__(self, *args, **kwargs):
        """
        Mix this class in with your Keras model class to simplify saving & loading from the Hugging Face Hub.

        NOTE - Dummy inputs are required to save/load models using this mixin. When saving, you are required to either:

        1. Assign an attribute to your class, self.dummy_inputs, that defines the inputs to be passed to the
           model's call function to build the model.
        2. Pass the dummy_inputs kwarg to save_pretrained. We will save it along with the model (as if it were
           an attribute).

        Example::

            >>> from huggingface_hub import KerasModelHubMixin

            >>> class MyModel(tf.keras.Model, KerasModelHubMixin):
            ...     def __init__(self, **kwargs):
            ...         super().__init__()
            ...         self.config = kwargs.pop("config", None)
            ...         self.dummy_inputs = ...
            ...         self.layer = ...
            ...     def call(self, ...):
            ...         return ...

            >>> model = MyModel()
            >>> model.save_pretrained("mymodel", push_to_hub=False)  # Saving model weights in the directory
            >>> model.push_to_hub("mymodel", "model-1")  # Pushing model weights to the Hub

            >>> # Downloading weights from the Hub; the model will be initialized from those weights
            >>> model = MyModel.from_pretrained("username/mymodel@main")
        """

    def _save_pretrained(self, save_directory, dummy_inputs=None, **kwargs):
        dummy_inputs = (
            dummy_inputs
            if dummy_inputs is not None
            else getattr(self, "dummy_inputs", None)
        )

        if dummy_inputs is None:
            raise RuntimeError(
                "You must either provide dummy inputs or have them assigned as an attribute of this model"
            )

        # Run a forward pass on the dummy inputs so the model is built before saving
        _ = self(dummy_inputs, training=False)

        save_directory = Path(save_directory)
        model_file = save_directory / self._WEIGHTS_NAME
        self.save_weights(model_file)
        logger.info(f"Model weights saved in {model_file}")

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token,
        by_name=False,
        **model_kwargs,
    ):
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, cls._WEIGHTS_NAME)
        else:
            model_file = hf_hub_download(
                repo_id=model_id,
                filename=cls._WEIGHTS_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                use_auth_token=use_auth_token,
                local_files_only=local_files_only,
            )

        model = cls(**model_kwargs)

        if not hasattr(model, "dummy_inputs") or model.dummy_inputs is None:
            raise ValueError("Model must have a dummy_inputs attribute")

        # Build the model with a forward pass, load the saved weights, then run
        # the forward pass again so the loaded weights are fully materialized
        _ = model(model.dummy_inputs, training=False)

        model.load_weights(model_file, by_name=by_name)

        _ = model(model.dummy_inputs, training=False)

        return model

Review comment on lines +71 to +77 (the _from_pretrained arguments): My intuition tells me all of these should be optional kwargs, even if they're only supposed to be called internally.
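A rough sketch of what that suggestion could look like; the default values below are assumptions for illustration, not taken from the PR:

    @classmethod
    def _from_pretrained(
        cls,
        model_id,
        revision=None,
        cache_dir=None,
        force_download=False,
        proxies=None,
        resume_download=False,
        local_files_only=False,
        use_auth_token=None,
        by_name=False,
        **model_kwargs,
    ):
        # Hypothetical defaults so internal callers may omit arguments
        ...

Separately, since the class docstring elides parts of its example, here is a fuller, runnable version of it; the Dense layer and input shapes are illustrative assumptions, not from the PR:

import tensorflow as tf
from huggingface_hub import KerasModelHubMixin

class MyModel(tf.keras.Model, KerasModelHubMixin):
    def __init__(self, **kwargs):
        super().__init__()
        self.config = kwargs.pop("config", None)
        self.dummy_inputs = tf.ones([2, 2])  # required by the mixin to build the model
        self.layer = tf.keras.layers.Dense(2)

    def call(self, inputs):
        return self.layer(inputs)

model = MyModel()
model.save_pretrained("mymodel", push_to_hub=False)  # writes tf_model.h5 into ./mymodel
reloaded = MyModel.from_pretrained("mymodel")  # builds the model, then loads the saved weights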
@@ -0,0 +1,153 @@
import os
import shutil
import time
import unittest

from huggingface_hub import HfApi
from huggingface_hub.file_download import is_tf_available
from huggingface_hub.keras_mixin import KerasModelHubMixin

from .testing_constants import ENDPOINT_STAGING, PASS, USER
from .testing_utils import set_write_permission_and_retry


REPO_NAME = "mixin-repo-{}".format(int(time.time() * 10e3))

WORKING_REPO_SUBDIR = "fixtures/working_repo_3"
WORKING_REPO_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR
)

if is_tf_available():
    import tensorflow as tf


def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow.

    These tests are skipped when TensorFlow isn't installed.
    """
    if not is_tf_available():
        return unittest.skip("test requires TensorFlow")(test_case)
    else:
        return test_case


if is_tf_available():

    class DummyModel(tf.keras.Model, KerasModelHubMixin):
        def __init__(self, **kwargs):
            super().__init__()
            self.config = kwargs.pop("config", None)
            self.l1 = tf.keras.layers.Dense(2, activation="relu")
            dummy_batch_size = input_dim = 2
            self.dummy_inputs = tf.ones([dummy_batch_size, input_dim])

        def call(self, x):
            return self.l1(x)


else:
    DummyModel = None


@require_tf
class HubMixingCommonTest(unittest.TestCase):
    _api = HfApi(endpoint=ENDPOINT_STAGING)

Review comment on lines +57 to +58: Is this super class really needed? We only have one test class, so maybe the API wrapper should be instantiated in setUpClass?

Reply (author): You're probably right. I didn't want to mess with it too much, though. Again, just following the previous file. I did find the superclass useful (if I remember right) when I played around with parametrizing the tests so we could run them over TF/PyTorch/etc.
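A minimal sketch of the reviewer's alternative, assuming the single test class (illustrative, not part of the diff):

@require_tf
class HubMixingTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Instantiate the API wrapper here instead of on a shared superclass
        cls._api = HfApi(endpoint=ENDPOINT_STAGING)
        cls._token = cls._api.login(username=USER, password=PASS)

The diff as submitted keeps the superclass: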
@require_tf
class HubMixingTest(HubMixingCommonTest):

    def tearDown(self) -> None:
        try:
            shutil.rmtree(WORKING_REPO_DIR, onerror=set_write_permission_and_retry)
        except FileNotFoundError:
            pass

    @classmethod
    def setUpClass(cls):
        """
        Share this valid token in all tests below.
        """
        cls._token = cls._api.login(username=USER, password=PASS)

    def test_save_pretrained(self):
        model = DummyModel()

        model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        self.assertTrue("tf_model.h5" in files)
        self.assertEqual(len(files), 1)

        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"}
        )
        files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}")
        self.assertTrue("config.json" in files)
        self.assertTrue("tf_model.h5" in files)
        self.assertEqual(len(files), 2)

    def test_keras_from_pretrained_weights(self):
        model = DummyModel()
        model.dummy_inputs = None
        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}", dummy_inputs=tf.ones([2, 2])
        )
        assert model.built
        new_model = DummyModel.from_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}")

        # Check the reloaded model's weights match the original model's weights
        self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0])))

        # Check a new model's weights are not the same as the reloaded model's weights
        another_model = DummyModel()
        another_model(tf.ones([2, 2]))
        self.assertFalse(
            tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0]))
            .numpy()
            .item()
        )

    def test_rel_path_from_pretrained(self):
        model = DummyModel()
        model.save_pretrained(
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED",
            config={"num": 10, "act": "gelu_fast"},
        )

        model = DummyModel.from_pretrained(
            f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED"
        )
        self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"})

    def test_abs_path_from_pretrained(self):
        model = DummyModel()
        model.save_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED",
            config={"num": 10, "act": "gelu_fast"},
        )

        model = DummyModel.from_pretrained(
            f"{WORKING_REPO_DIR}/{REPO_NAME}-FROM_PRETRAINED"
        )
        self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"})

    def test_push_to_hub(self):
        model = DummyModel()
        model.push_to_hub(
            repo_path_or_name=f"{WORKING_REPO_DIR}/{REPO_NAME}-PUSH_TO_HUB",
            api_endpoint=ENDPOINT_STAGING,
            use_auth_token=self._token,
            git_user="ci",
            git_email="[email protected]",
            config={"num": 7, "act": "gelu_fast"},
        )

        model_info = self._api.model_info(
            f"{USER}/{REPO_NAME}-PUSH_TO_HUB",
        )
        self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}-PUSH_TO_HUB")

        self._api.delete_repo(token=self._token, name=f"{REPO_NAME}-PUSH_TO_HUB")
Review comment: The kwargs are advertised as the kwargs to be passed to push_to_hub, yet they're passed to _save_pretrained. I'd argue the kwargs to pass to each method ought to be different!
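A purely illustrative way to separate the two groups; this signature is an assumption, not the PR's API:

def push_to_hub(self, repo_path_or_name, save_kwargs=None, **push_kwargs):
    # Hypothetical split: save_kwargs are forwarded to save_pretrained /
    # _save_pretrained (e.g. dummy_inputs), while push_kwargs stay with the
    # push itself (e.g. git_user, git_email).
    save_kwargs = save_kwargs or {}
    self.save_pretrained(repo_path_or_name, **save_kwargs)
    # ... repo creation and git push would consume push_kwargs here ...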