Add a kwarg to choose if a model loaded from the preset should be compiled or not #1735
Comments
@grasskin has solved something similar I think |
Hi @markomitos, this is a bit different since there's no Keras 2 version that we're maintaining compatibility with by introducing those kwargs. Note that you could also just call compile independently on Gemma, which accepts all arguments: https://github.com/keras-team/keras-nlp/blob/b890ca9b06375352bb3ef84411fcca4222897d2f/keras_nlp/src/models/gemma/gemma_causal_lm.py#L172. Regardless, I'm not sure whether we can only use default compilation in the base class, @mattdangerw? Piping in kwargs might be fine. |
@markomitos I think we'd need to know more about what is failing here. The idea was that default compilation would not break any workflow in which Sounds like that is relying on some bad assumptions, but what actually broke? If it's a silly bug, maybe we just fix the bug. If it's something we can't solve reasonably, then adding a |
@mattdangerw The reason for the failure is this check inside TensorFlow Federated, which constructs a TensorFlow Federated model from a tf/keras model and also traces the TensorFlow graph. There is an explicit check whether the model is compiled: In the internal code I have adapted the compiled check to support both Keras 2 and Keras 3 models. I am trying to add support for Keras 3 in this library, and I have modified this class to support Keras 3 models as well (unfortunately I cannot share this code at this time, as it is internal). The only thing that was stopping me from converting a Keras 3 model to a TensorFlow Federated model is that it is already compiled. When I removed the compile code from the constructor in the source code, it worked as intended. |
sgtm! @markomitos added this here: #1787. Not merged yet. Will this work for you?
# from_preset.
classifier = keras_nlp.models.BertClassifier.from_preset(
    "bert_base_multi",
    num_classes=2,
    compile=False,
)
# Direct constructor.
classifier = keras_nlp.models.BertClassifier(
    backbone=backbone,
    preprocessor=preprocessor,
    num_classes=2,
    compile=False,
) |
@mattdangerw That is exactly what I was thinking, thank you! |
Is your feature request related to a problem? Please describe.
When trying to convert a Gemma 2 model loaded from a preset into a TensorFlow Federated model, I ran into the issue that I need the model to not be compiled when it is loaded from the preset.
Describe the solution you'd like
I want a boolean kwarg in the keras_nlp.models.GemmaCausalLM.from_preset method that can be further passed down to the model here:
https://github.com/keras-team/keras-nlp/blob/c1afb070ded549d0c18fad812f04d4604553fd59/keras_nlp/src/models/task.py#L273
That way there can be an if statement here:
https://github.com/keras-team/keras-nlp/blob/c1afb070ded549d0c18fad812f04d4604553fd59/keras_nlp/src/models/causal_lm.py#L77
Which can determine if the model should be compiled or not depending on the kwarg passed. This can be true by default, that way it will not affect any existing code.
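A minimal sketch of the pattern described above, assuming a simplified task class. `Task`, its attributes, and the placeholder backbone string are hypothetical stand-ins for illustration, not the actual keras_nlp implementation; only the idea of a `compile` kwarg that defaults to true and is forwarded from `from_preset` to the constructor comes from this request.

```python
class Task:
    """Hypothetical stand-in for a task model (not the real keras_nlp class)."""

    def __init__(self, backbone, compile=True):
        self.backbone = backbone
        self.optimizer = None
        self.loss = None
        self.compiled = False
        # The proposed "if statement": only apply default compilation
        # when the caller has not opted out.
        if compile:
            self.compile()

    def compile(self, optimizer="auto", loss="auto"):
        # Placeholder for the default compilation step.
        self.optimizer = optimizer
        self.loss = loss
        self.compiled = True

    @classmethod
    def from_preset(cls, preset, **kwargs):
        # Placeholder for loading a backbone from a preset name.
        backbone = f"backbone<{preset}>"
        # Extra kwargs, including `compile`, are passed down to the constructor.
        return cls(backbone=backbone, **kwargs)


# Default behaviour is unchanged: the model comes back compiled.
default_model = Task.from_preset("some_preset")

# Opting out leaves the model uncompiled; it can still be compiled later.
raw_model = Task.from_preset("some_preset", compile=False)
```

Because the kwarg defaults to true, existing callers that never pass `compile` would see the same behaviour as before.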
Describe alternatives you've considered
Alternatively, you could eliminate the compilation from the preset loading altogether; however, I think this would affect too much existing code, and it is not necessary for this feature.
Additional context
My use case is that I am trying to add keras 3 support to TensorFlow Federated and I want to load a preexisting model and convert it to a TensorFlow Federated Model. When I removed the compile from the source code this worked without issues.