Port gpt2 transformers checkpoint #1704
Conversation
@ariG23498, thank you for your amazing reference repository: ariG23498/keras-nlp-hugging-face-integration 🙏
Thanks! Why do we need to add a new public-facing argument, `hf_key_prefix`?
I have noticed that some model key names have a prefix that breaks model porting. If one inspects the checkpoint's safetensors file, the keys carry a prefix such as `transformer.`. With the use of the new `hf_key_prefix` argument, such a checkpoint can still be loaded. An example would be:

```python
import keras_nlp

model = keras_nlp.models.GPT2CausalLM.from_preset(
    "hf://distilbert/distilgpt2",
    hf_key_prefix="transformer",
)
print(model.generate(["what is"], max_length=15))
```
Is there any other prefix we need to check for besides `transformer`? Could we detect the prefix automatically instead of adding a new argument? I think that'd be a lot more usable, and keep the API clean. In practice I don't think people will actually understand what went wrong, look up the safetensor content, discover the key, and pass it.
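For context, listing the tensor names in a checkpoint's safetensors file is straightforward; a minimal sketch, assuming the repo ships a single `model.safetensors` file (the repo id is just the one from the example above):

```python
from huggingface_hub import hf_hub_download
from safetensors import safe_open

# Download the checkpoint file and print its tensor names.
path = hf_hub_download("distilbert/distilgpt2", "model.safetensors")
with safe_open(path, framework="np") as f:
    for name in f.keys():
        print(name)  # e.g. "transformer.h.0.attn.c_attn.weight"
```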
@mattdangerw I agree completely with your pointers. Upon talking to Matt, we think it would be easiest to apply a regex to capture the prefix, if any. To be precise, this is where we would like to apply the regex in the conversion code. I think that would remove the key listing and checking altogether. WDYT?
Yeah, something like this demo, using a regex here to capture the prefix.
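A rough sketch of what that regex could look like (the pattern and function name below are illustrative, not the demo's actual code; it assumes GPT-2 style key names such as `h.0.attn.c_attn.weight`, `wte.weight`, `wpe.weight`):

```python
import re

# Capture everything before the first GPT-2 key component, e.g. "transformer."
# in "transformer.h.0.attn.c_attn.weight", or "" when there is no prefix.
PREFIX_PATTERN = re.compile(r"^(.*?)(?:h\.\d+\.|wte\.|wpe\.|ln_f\.)")

def detect_prefix(hf_keys):
    for key in hf_keys:
        match = PREFIX_PATTERN.match(key)
        if match:
            return match.group(1)
    return ""
```

With this, the prefix is discovered from the checkpoint itself and no public argument is needed.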
We should make sure to handle both the sharded and single-file safetensor cases. I do think we could handle this in the loading code.

Also, please run the formatting script!
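For reference, a minimal sketch of covering both cases (the helper name is hypothetical): sharded checkpoints ship a `model.safetensors.index.json` whose `weight_map` maps every tensor name to its shard file, so the keys can be read from the index when it exists and from `model.safetensors` otherwise.

```python
import json

from huggingface_hub import hf_hub_download
from safetensors import safe_open

def list_safetensor_keys(repo_id):
    # Hypothetical helper: collect all tensor names, sharded or single-file.
    try:
        # Sharded case: the index file maps tensor names to shard files.
        index_path = hf_hub_download(repo_id, "model.safetensors.index.json")
        with open(index_path) as f:
            return sorted(json.load(f)["weight_map"].keys())
    except Exception:
        # Single-file case: read the names straight from model.safetensors.
        path = hf_hub_download(repo_id, "model.safetensors")
        with safe_open(path, framework="np") as f:
            return sorted(f.keys())
```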
```python
def get_prefix(self, key, all_keys):
    for k in all_keys:
        if k.endswith(key) and k != key:
            prefix = k[: -len(key)]
            return prefix + key
    return key
```
I like the implementation! WDYT @mattdangerw
My 2 cents:
- We should use better variable naming.
- The name of the function is misleading.
- We should add a docstring here, so that we are well aware of the problem and how we solve it.
Come to think of it, I think a better approach would be to loop over the keys all at once and build a one-to-one mapping of the HF keys and the Keras keys.
Return the map, and then use that map later (see the sketch after this list).
This bypasses the following:
- Looping over all the keys multiple times
- Using this function multiple times
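A minimal sketch of that idea, assuming the prefix has already been detected (the names here are hypothetical, not the final implementation): it maps each expected, unprefixed name to the actual HF key in a single pass, and translating those to the final Keras weight names would layer on top of this map.

```python
def build_key_map(hf_keys, prefix):
    # One pass over the checkpoint keys: expected (unprefixed) name -> actual HF key.
    key_map = {}
    for hf_key in hf_keys:
        stripped = hf_key[len(prefix):] if hf_key.startswith(prefix) else hf_key
        key_map[stripped] = hf_key
    return key_map

# Later lookups are O(1) and never rescan the key list:
# key_map["h.0.attn.c_attn.weight"] -> "transformer.h.0.attn.c_attn.weight"
```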
> Come to think of it, I think a better approach would be to loop over the keys all at once and build a one-to-one mapping of the HF keys and the Keras keys.

Makes sense. It will be efficient, and the real power of this will show up when there is a large number of keys. Mapping everything at once instead of re-running the loop each time also reduces the time complexity.
Yeah, caching the prefix in some form sounds good.
Is the prefix always the same for all weights? If so, we could probably do something like this...

```python
def get_prefix(self, key, dict_like):
    # The prefix should be the same for every weight, so compute it once and cache it.
    if self.prefix is not None:
        return self.prefix
    keys = dict_like.keys()
    if key in keys:
        self.prefix = ""
    else:
        # Figure out the correct prefix, e.g. by finding a key that ends with `key`.
        for k in keys:
            if k.endswith(key) and k != key:
                self.prefix = k[: -len(key)]
                break
    return self.prefix
```
Looking good! Mostly minor nits, but the comment on the conversion about not hardcoding the query/key/value size is important.
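To illustrate that point: GPT-2's fused `c_attn` kernel stacks the query, key, and value projections along its last axis, so the split width should come from the model's hidden size rather than a hardcoded number like 768 or 2304. A rough sketch (variable names are illustrative):

```python
import numpy as np

def split_qkv(c_attn_kernel, hidden_dim):
    # c_attn_kernel has shape (hidden_dim, 3 * hidden_dim); split it into the
    # query, key, and value kernels using the configured hidden size.
    query, key, value = np.split(c_attn_kernel, 3, axis=-1)
    assert query.shape[-1] == hidden_dim
    return query, key, value
```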
Thanks! LGTM
Hi @mattdangerw @ariG23498,
Ported the GPT2 transformers checkpoint to KerasNLP. Please check. Thank you!