-
Notifications
You must be signed in to change notification settings - Fork 613
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update Keras Mixin #284
Update Keras Mixin #284
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is great, thanks for working on this @nateraw! Really stoked about this refactor :)
def save_pretrained_keras( | ||
model, save_directory: str, config: Optional[Dict[str, Any]] = None | ||
): | ||
"""Saves a Keras model to save_directory in SavedModel format. Use this if you're using the Functional or Sequential APIs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the decision to switch from .h5
to SavedModel
is sound.
I think we should monitor issues closely to ensure that this doesn't create an incompatibility/incomprehension layer between transformers
which outputs .h5
files and huggingface_hub
which outputs SavedModel
format.
src/huggingface_hub/keras_mixin.py
Outdated
if use_auth_token is None and repo_url is None: | ||
token = HfFolder.get_token() | ||
if token is None: | ||
raise ValueError( | ||
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and " | ||
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own " | ||
"token as the `use_auth_token` argument." | ||
) | ||
elif isinstance(use_auth_token, str): | ||
token = use_auth_token | ||
else: | ||
token = None | ||
|
||
if repo_path_or_name is None: | ||
repo_path_or_name = repo_url.split("/")[-1] | ||
|
||
# If no URL is passed and there's no path to a directory containing files, create a repo | ||
if repo_url is None and not os.path.exists(repo_path_or_name): | ||
repo_name = Path(repo_path_or_name).name | ||
repo_url = HfApi(endpoint=api_endpoint).create_repo( | ||
token, | ||
repo_name, | ||
organization=organization, | ||
private=private, | ||
repo_type=None, | ||
exist_ok=True, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic should already be handled by the __init__
of Repository
.
The initialization can currently take care of:
- Setting the correct token according to
use_auth_token
according tostr
/bool
- Identifying the repo name from the URL
- Creating the repository if it doesn't exist.
Could you try working with it? Please let me know if you run into any issues, hopefully it makes your life easier.
It would be nice to also add a Github Actions job with Keras to test this automatically on CI! |
Co-authored-by: Lysandre Debut <[email protected]>
123aa09
to
8182c32
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great, thanks for working on it and implementing a robust test suite. I've only left nitpicks, this looks about ready to merge!
@@ -57,6 +57,23 @@ jobs: | |||
|
|||
- run: pytest -sv ./tests/ | |||
|
|||
build_tensorflow: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice
I came late to the last round of review, but great work! 🔥 |
Both mentioned in #230
Note - I had to switch to using
snapshot_download
because:keras_model.pb
metadata.pb
,variables/
, etc.Using
forced_filename
also could have worked, but you'd have to gather up all the files withinvariables
/assets
/etcI've seen some folks just zip up the specific file tree instead due to similar issues.