
Add a kwarg to choose if a model loaded from the preset should be compiled or not #1735

Closed
markomitos opened this issue Aug 5, 2024 · 6 comments · Fixed by #1787
Labels: Gemma (Gemma model specific issues)

@markomitos

Is your feature request related to a problem? Please describe.
When trying to convert a Gemma 2 model loaded from a preset into a TensorFlow Federated model, I ran into the issue that the model must not be compiled when it is loaded from the preset.

Describe the solution you'd like
I want a boolean kwarg in the keras_nlp.models.GemmaCausalLM.from_preset method that can be further passed down to the model here:
https://github.com/keras-team/keras-nlp/blob/c1afb070ded549d0c18fad812f04d4604553fd59/keras_nlp/src/models/task.py#L273

That way there can be an if statement here:
https://github.com/keras-team/keras-nlp/blob/c1afb070ded549d0c18fad812f04d4604553fd59/keras_nlp/src/models/causal_lm.py#L77
which would determine whether the model should be compiled, depending on the kwarg passed. It could default to True so that no existing code is affected.
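
For illustration, here is a minimal, self-contained sketch of the idea (the class and argument names are placeholders, not the actual keras_nlp internals):

# Placeholder sketch of the proposed behavior, not keras_nlp source code:
# from_preset() takes a `compile` flag that defaults to True, so existing
# callers keep the current behavior.
class FakeTask:
    def __init__(self):
        self.compiled = False

    def compile(self):
        # Stand-in for the default compile() call in the task constructor.
        self.compiled = True

    @classmethod
    def from_preset(cls, preset, compile=True, **kwargs):
        task = cls(**kwargs)
        if compile:
            task.compile()
        return task


uncompiled = FakeTask.from_preset("some_preset", compile=False)
assert not uncompiled.compiled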

Describe alternatives you've considered
Alternatively, you could eliminate the compilation from preset loading altogether; however, I think this would affect too much existing code, and it is not necessary for this feature.

Additional context
My use case is that I am trying to add Keras 3 support to TensorFlow Federated, and I want to load a pre-existing model and convert it to a TensorFlow Federated model. When I removed the compile call from the source code, this worked without issues.

The github-actions bot added the Gemma (Gemma model specific issues) label on Aug 5, 2024.
@markomitos (Author)

@grasskin has solved something similar, I think.

@grasskin (Member) commented Aug 8, 2024

Hi @markomitos, this is a bit different, since there is no Keras 2 version that we are maintaining compatibility with by introducing those kwargs. Note that, regardless, you could also just call compile() on Gemma independently, which accepts all arguments: https://github.com/keras-team/keras-nlp/blob/b890ca9b06375352bb3ef84411fcca4222897d2f/keras_nlp/src/models/gemma/gemma_causal_lm.py#L172
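
For example, something along these lines (the preset name and the compile() arguments are only illustrative assumptions, not a recommendation):

import keras
import keras_nlp

# Re-calling compile() replaces the default compilation set up at load time.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)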

I'm not sure if we can keep default compilation only in the base class, @mattdangerw? Piping the kwargs through might be fine.

@mattdangerw (Member) commented Aug 14, 2024

@markomitos I think we'd need to know more about what is failing here.

The idea was that default compilation would not break any workflow in which compile() is called by the end user. compile() won't actually create new variables or an XLA function; that happens on the first call to fit()/predict()/evaluate(). So it's just supposed to be an unobtrusive way to allow easy getting-started usage, and it is overridden by anyone who wants to control their own loss, optimizer, and metrics.

Sounds like that is relying on some bad assumptions, but what actually broke?

If it's a silly bug, maybe we just fix the bug. If it's something we can't solve reasonably, then adding a compile=True arg to all task constructors (including from_preset()) sgtm.

@markomitos (Author)

@mattdangerw The failure comes from this check inside TensorFlow Federated, which constructs a TensorFlow Federated model from a tf/keras model and also traces the TensorFlow graph. There is an explicit check on whether the model is compiled:

https://github.com/google-parfait/tensorflow-federated/blob/6b8c06150b050452afcbef5eb12c029ff8cdc359/tensorflow_federated/python/learning/models/keras_utils.py#L121

In the internal code, I have adapted the compiled check to support both Keras 2 and Keras 3 models.

I am trying to add support for Keras 3 in this library, and I have modified this class to support Keras 3 models as well (unfortunately, I cannot share this code at the moment, as it is internal). The only thing stopping me from converting a Keras 3 model to a TensorFlow Federated model is that it is already compiled. When I removed the compile code from the constructor in the source code, it worked as intended.
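
As an illustration only (not the actual TensorFlow Federated code), a version-agnostic check could look roughly like this, assuming Keras 3 exposes a public `compiled` attribute while tf.keras/Keras 2 tracks the private `_is_compiled` flag:

def _is_compiled(keras_model):
    # Keras 3 sets a public `compiled` attribute in Model.compile();
    # tf.keras (Keras 2) uses the private `_is_compiled` attribute instead.
    return bool(
        getattr(keras_model, "compiled", False)
        or getattr(keras_model, "_is_compiled", False)
    )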

@mattdangerw (Member)

sgtm!

@markomitos added this here: #1787 (not merged yet). Will this work for you?

# from_preset.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_multi",
    num_classes=2,
    compile=False,
)
# Direct constructor.
classifier = keras_nlp.models.BertClassifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=2,
    compile=False,
)

@markomitos (Author)

@mattdangerw That is exactly what I was thinking, thank you!
