Port gpt2 transformers checkpoint #1704

Merged 12 commits into keras-team:master on Jul 29, 2024
Conversation

cosmo3769 (Contributor)

Hi @mattdangerw @ariG23498,

Ported the GPT2 transformers checkpoint into KerasNLP. Please check. Thank you!

@cosmo3769 (Contributor Author)

@ariG23498, thank you for your amazing reference repository: ariG23498/keras-nlp-hugging-face-integration 🙏

@mattdangerw (Member)

Thanks! Why do we need to add a new public-facing argument hf_key_prefix? Ideally we would keep our exposed API surface minimal.

@ariG23498 (Collaborator)

@mattdangerw

I have noticed that some models' checkpoint key names carry a prefix that breaks model porting.

If you look at distilgpt2, you will find that every key carries the prefix transformer. That prefix alone is enough to break the model porting code.

With hf_key_prefix I introduce a parameter that can be used to set the prefix while loading the model.

An example would be:

import keras_nlp

model = keras_nlp.models.GPT2CausalLM.from_preset(
    "hf://distilbert/distilgpt2",
    hf_key_prefix="transformer",
)

print(model.generate(["what is"], max_length=15))

@mattdangerw (Member)

Is there any other prefix we need to check for besides "transformer"? Can we write some code that either checks specifically for the "transformer" prefix or detects any prefix applied to all the weights? I'm not super familiar with the safetensors API, but there must be a way to list keys, right?

I think that'd be a lot more usable, and keep the API clean. In practice I don't think people will understand what went wrong, look up the safetensor contents, discover the prefix, and pass it in.
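
For reference, listing the keys is possible with the safetensors Python API. A minimal sketch, assuming a locally downloaded model.safetensors and the numpy binding ("np"):

from safetensors import safe_open

# Open the file lazily; no tensors are loaded just to list the key names.
with safe_open("model.safetensors", framework="np") as f:
    keys = list(f.keys())

# For distilgpt2 these look like "transformer.h.0.attn.c_attn.weight",
# i.e. every key carries a "transformer." prefix.
print(keys[:5])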

@ariG23498 (Collaborator)

@mattdangerw I completely agree with your points.

Upon talking to Matt, we think it would be easiest to apply a regex to capture the prefix, if any. The transformers library captures the prefix and then removes it while loading its models; unfortunately, we would not have that information here.

To be precise, this is where we would like to apply the regex:
https://github.com/keras-team/keras-nlp/blob/b6877df38d5ddadcd1f7c9c30498b933b4b6ee30/keras_nlp/src/utils/transformers/safetensor_utils.py#L50

I think that would remove the need to list and check keys altogether.

WDYT?
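
For illustration, a rough sketch of what such a regex could look like: detect the prefix from any key that contains a GPT-2 style block index ("h.<n>.") and strip it from every key. The helper name and the exact pattern are assumptions, not the code in this PR.

import re

def detect_prefix(all_keys):
    # Treat everything before the first GPT-2 block pattern "h.<n>." as
    # the checkpoint-wide prefix (e.g. "transformer." for distilgpt2).
    for key in all_keys:
        match = re.match(r"^(.*?)h\.\d+\.", key)
        if match:
            return match.group(1)
    return ""

keys = ["transformer.wte.weight", "transformer.h.0.attn.c_attn.weight"]
prefix = detect_prefix(keys)  # "transformer."
stripped = [k[len(prefix):] if k.startswith(prefix) else k for k in keys]
print(stripped)  # ['wte.weight', 'h.0.attn.c_attn.weight']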

@cosmo3769 (Contributor Author) commented Jul 22, 2024

> Upon talking to Matt, we think it would be easiest to apply a regex to capture the prefix, if any.

Yeah, something like this demo. I am using a regex there to get the prefix up to layer_index; tested with the gpt2 model.

@ariG23498 @mattdangerw

@mattdangerw (Member)

We should make sure to handle both the sharded and single-file safetensor cases. I do think we could handle this in SafetensorLoader. Since we get the full list of keys either via file.keys() or via safetensor_config["weight_map"].keys(), we can use that to resolve the actual key name as needed. Let's try to keep the implementation simple.
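
A minimal sketch of that key-listing step, assuming the loader already knows whether a sharded index file was found (safetensor_config and safetensor_file are illustrative names, not the actual attributes of SafetensorLoader):

def list_checkpoint_keys(safetensor_config, safetensor_file):
    if safetensor_config is not None:
        # Sharded checkpoint: model.safetensors.index.json maps every
        # weight name to the shard file that stores it.
        return list(safetensor_config["weight_map"].keys())
    # Single-file checkpoint: the open safetensors handle lists the keys.
    return list(safetensor_file.keys())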

@mattdangerw (Member)

Also, please run the formatting script!

Comment on lines 46 to 51
def get_prefix(self, key, all_keys):
    for k in all_keys:
        if k.endswith(key) and k != key:
            prefix = k[: -len(key)]
            return prefix + key
    return key
@ariG23498 (Collaborator)

I like the implementation! WDYT @mattdangerw?

My 2 cents:

  • We should use better variable naming.
  • The name of the function is misleading.
  • Add a docstring here, so that we are aware of the problem and how it is solved (a rough sketch follows below).
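
For example, a hypothetical rewrite along those lines; the method name, variable names, and docstring wording are only suggestions:

def resolve_hf_key(self, hf_key, all_checkpoint_keys):
    """Return the stored checkpoint key that corresponds to `hf_key`.

    Some Hugging Face checkpoints (e.g. distilgpt2) store every weight
    under an extra prefix such as "transformer.". Look for a stored key
    that ends with `hf_key`, so callers can read weights without knowing
    the prefix in advance.
    """
    for checkpoint_key in all_checkpoint_keys:
        if checkpoint_key.endswith(hf_key) and checkpoint_key != hf_key:
            return checkpoint_key
    return hf_key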

@ariG23498 (Collaborator)

Come to think of it, I think a better approach would be to loop over all the keys at once and build a one-to-one mapping of the hf keys to the keras keys.

Return the map, and then use that map later (see the sketch after this list).

This bypasses the following:

  • Looping over all the keys multiple times
  • Calling this function multiple times
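
A rough sketch of that idea, with illustrative names (the prefix detection itself is assumed to happen elsewhere):

def build_key_map(all_hf_keys, prefix=""):
    # Map the bare key name used by the conversion code to the key
    # actually stored in the checkpoint; built once, used many times.
    return {
        (key[len(prefix):] if prefix and key.startswith(prefix) else key): key
        for key in all_hf_keys
    }

key_map = build_key_map(
    ["transformer.wte.weight", "transformer.h.0.mlp.c_fc.weight"],
    prefix="transformer.",
)
print(key_map["wte.weight"])  # "transformer.wte.weight"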

@cosmo3769 (Contributor Author)

> Come to think of it, I think a better approach would be to loop over all the keys at once and build a one-to-one mapping of the hf keys to the keras keys.

Makes sense. It will be efficient, and the real power of this will show up when there is a large number of keys. Mapping everything once instead of running the loop every time also reduces the time complexity.

@mattdangerw (Member) commented Jul 24, 2024

Yeah, caching the prefix in some form sounds good.

Is the prefix always the same for all weights? If so, we could probably do something like this...

def get_prefix(self, key, dict_like):
    if self.prefix is not None:
        return self.prefix
    keys = dict_like.keys()
    if key in keys:
        self.prefix = ""
    else:
        # Some code to figure out the correct prefix.
        self.prefix = ...
    return self.prefix

@mattdangerw (Member) left a comment

Looking good! Mostly minor nits, but the comment on the conversion (don't hardcode the query/key/value size) is important.

@mattdangerw (Member) left a comment

Thanks! LGTM

@mattdangerw added the kokoro:force-run (Runs Tests on GPU) label on Jul 29, 2024
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label on Jul 29, 2024
@mattdangerw merged commit cb49405 into keras-team:master on Jul 29, 2024
10 checks passed