
Add phi3 #1597

Merged · 61 commits · May 17, 2024

Conversation

@abuelnasr0 (Contributor) commented Apr 25, 2024

This PR adds the Phi-3 model.
The PR is not ready for merge yet. I still need to:

  • Add a conversion script and numeric check
  • Add Phi3SuScaledRotaryEmbedding for microsoft/Phi-3-mini-128k-instruct

EDIT:
It's ready for review now.

@abuelnasr0 (Contributor, Author)

Phi3Backbone is now ready to be reviewed and merged!

I have checked numerics for both models, phi3_mini_4k_instruct_en and phi3_mini_128k_instruct_en, and the check produces a good mean difference between the two outputs for float32 and float16 (i.e. 5.2247e-08), but for bfloat16 it produces a lower-quality result (i.e. -0.0004). I think this is normal, isn't it?
Here is a notebook with the script run on the two models: https://www.kaggle.com/code/mohamedabuelnasr/phi3-keras-conversion

@tirthasheshpatel (Contributor) commented Apr 29, 2024

Thanks for the work on this @abuelnasr0, this is awesome!

> I have checked numerics for both models, phi3_mini_4k_instruct_en and phi3_mini_128k_instruct_en, and the check produces a good mean difference between the two outputs for float32 and float16 (i.e. 5.2247e-08), but for bfloat16 it produces a lower-quality result (i.e. -0.0004). I think this is normal, isn't it?

Yes, that's pretty good. Are these absolute tolerance values or relative? Either way, it's below the machine precision of 32-bit and 16-bit floating-point values, so they definitely match!

@abuelnasr0 (Contributor, Author)

@tirthasheshpatel

> Are these absolute tolerance values or relative?

It was supposed to calculate the absolute difference, but I took another look at the function and found that I was calculating the mean of the difference, not of the absolute difference. I have corrected it and run the conversion script again. The results became worse but are still acceptable for float32 (3.0725e-06); for bfloat16 they are bad (0.0254 for the 128k model and 0.0469 for the 4k model), and for float16 (0.0046).
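
For reference, here is a minimal sketch of the corrected check, with random placeholder arrays standing in for the real Keras and huggingface logits:

import numpy as np

# Placeholder arrays standing in for model outputs; the real conversion
# script compares actual logits produced from the same input ids.
rng = np.random.default_rng(0)
keras_logits = rng.normal(size=(1, 8, 32064))
hf_logits = keras_logits + rng.normal(scale=1e-6, size=keras_logits.shape)

# The signed mean lets errors of opposite sign cancel, which is why the
# first numbers looked better; the mean absolute difference cannot cancel.
print("mean diff:    ", np.mean(keras_logits - hf_logits))
print("mean abs diff:", np.mean(np.abs(keras_logits - hf_logits)))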

@mattdangerw (Member) left a comment

Thanks! I still need to go through this more, but I left some initial comments...

Also, see #1605 (comment): we have to do a large rebase on this. Sorry about that!

keras_nlp/models/phi3/phi3_attention.py (outdated, resolved)
keras_nlp/models/phi3/phi3_attention.py (outdated, resolved)
length that the model was trained with. Defaults to `4096`.
rope_max_wavelength (int, optional): The maximum angular wavelength of
the sine/cosine curves, for rotary embeddings. Defaults to `10000`.
rope_scaling_type (str, optional): The type of the rope scaling. Can be
@mattdangerw (Member):

What does "su" stand for? Just the original author of the RoPE paper? We try to avoid short/inscrutable names like this, but I'm not sure there's a good alternative.

Is this called "su" scaling anywhere outside of huggingface? Also, will we need more options here than just two?

@abuelnasr0 (Contributor, Author):

Honestly, I have no idea what "su" stands for; it was just implemented that way in the official model repo on huggingface: https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/8a362e755d2faf8cec2bf98850ce2216023d178a/modeling_phi3.py#L142
I tried to search for the source of this naming before your question, but I didn't find anything. It could stand for the paper author, but the implementation is different from what is proposed in the paper.

> Is this called "su" scaling anywhere outside of huggingface?

I didn't see this term anywhere else.

> Also, will we need more options here than just two?

Maybe we will also need 'yarn', if they publish the larger models and they use yarn: https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/blob/8a362e755d2faf8cec2bf98850ce2216023d178a/modeling_phi3.py#L183
yarn is introduced here: https://arxiv.org/pdf/2309.00071

@abuelnasr0 (Contributor, Author):

I took another look at the phi-3 paper, and they actually mention that they used LongRoPE. Maybe I was in a hurry when I searched the first time 😅.
So yes, 'su' stands for the original paper author, but with scaling as introduced in the LongRoPE paper. The layer name in the original implementation is SuScaled, but they made the type names shorter in the config, just su or yarn; likewise, the yarn type is not plain yarn but YarnScaled.
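
For anyone following along, here is a rough numpy sketch of what the su/LongRoPE-style scaling does, based on my reading of the huggingface implementation linked above. The ext_factors default is a placeholder; the real per-dimension short/long factors ship in the model config:

import math
import numpy as np

def su_scaled_rope(position, rotary_dim, max_wavelength=10000.0,
                   pretraining_sequence_length=4096,
                   max_sequence_length=131072,
                   ext_factors=None):
    # Rescale the base inverse frequencies per dimension; the real factors
    # come from the config ("short_factor" / "long_factor").
    if ext_factors is None:
        ext_factors = np.ones(rotary_dim // 2)
    inverse_freq = 1.0 / (
        ext_factors * max_wavelength ** (np.arange(0, rotary_dim, 2) / rotary_dim)
    )
    freq = position * inverse_freq
    # On top of the frequency rescaling, the cos/sin embeddings are
    # multiplied by a factor derived from the context-extension ratio.
    scale = max_sequence_length / pretraining_sequence_length
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(pretraining_sequence_length)
    )
    return np.cos(freq) * scaling_factor, np.sin(freq) * scaling_factor

cos, sin = su_scaled_rope(position=5.0, rotary_dim=96)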

keras_nlp/models/phi3/phi3_decoder.py (outdated, resolved)
keras_nlp/models/phi3/phi3_decoder.py (outdated, resolved)
@abuelnasr0 (Contributor, Author) commented May 7, 2024

The model is ready now.
And the output matches the huggingface output for both models.

  • phi3-mini-4k model:

[screenshot: 4k-phi3-model]

  • phi3-mini-128k model:

[screenshot: 128k-phi3-model]

but there is a problem with the tokenizer, described here:

text = "<|user|>\nHow to win?<|end|>\n<|assistant|>"
# The output after adding special_tokens as user_defined_symbols to the sentence_piece model.
keras_nlp  : [1, 29871, 32010, 13, 5328, 304, 5401, 29973, 32007, 13, 32001]
# Same as keras but without adding '▁' at the beginning. Can be configured in the spm model.
llama_cpp  : [1, 32010, 13, 5328, 304, 5401, 29973, 32007, 13, 32001]
# Removes '\n' (LF token) completely.
# Adds '▁' at the beginning (if the text starts with a non-special token) and after each special token.
hf_fast_tok: [1, 32010, 1128, 304, 5401, 29973, 32007, 32001]
# Removes '\n' (LF token) completely. Adds '▁' at the beginning.
# Same as keras when the text doesn't contain '\n'.
hf_tok     : [1, 29871, 32010, 5328, 304, 5401, 29973, 32007, 32001]

The huggingface output should match the sentencepiece output after adding the special tokens, but huggingface handles special tokens outside the sentencepiece library; that's why the output doesn't match.

LlamaTokenizer and LlamaFastTokenizer in huggingface aren't consistent. But if we have to match huggingface output, we should try to match LlamaFastTokenizer, as it is used in the example on the official model page; that would require a lot of workarounds, though, for example like here #1445

NOTE: The generations match in the screenshots because I used a text that is tokenized the same in keras_nlp and huggingface using LlamaTokenizer.
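
For anyone who wants to reproduce the comparison, here is a rough sketch. It assumes a local tokenizer.model with the special tokens already added as user_defined_symbols (for the keras_nlp-style numbers); the checkpoint name is the 4k one, and BOS handling is omitted on the sentencepiece side:

import sentencepiece as spm
from transformers import AutoTokenizer

text = "<|user|>\nHow to win?<|end|>\n<|assistant|>"

# Raw sentencepiece, which is how the keras_nlp tokenizer behaves once the
# special tokens are registered as user_defined_symbols (no BOS added here).
sp = spm.SentencePieceProcessor(model_file="tokenizer.model")
print("sentencepiece:", sp.encode(text))

# huggingface handles the special tokens outside the sentencepiece library,
# which is where the outputs start to diverge.
fast = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", use_fast=True)
slow = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", use_fast=False)
print("hf_fast_tok:", fast(text)["input_ids"])
print("hf_tok     :", slow(text)["input_ids"])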

@abuelnasr0 abuelnasr0 requested a review from mattdangerw May 8, 2024 20:27
@mattdangerw (Member) left a comment

Very impressive work!! Thank you.

I think this is basically good to go, just left some comments on arg naming and the behavior of the specialized rope layer.

Thanks so much!!

keras_nlp/src/models/phi3/Phi3_preprocessor_test.py (outdated, resolved)
decoder.
max_sequence_length (int, optional): The maximum sequence length
that this model might ever be used with. Defaults to `4096`.
original_max_sequence_length (int, optional): The maximum sequence
@mattdangerw (Member):

What if we call these max_sequence_length and training_sequence_length? original_max_sequence_length is just very clunky as a name.

@abuelnasr0 (Contributor, Author):

I changed it again, from training_sequence_length to pretraining_sequence_length. I think this is clearer. I can revert that commit if you would like to keep training_sequence_length.

"padding_mask": padding_mask,
}

def generate(self, inputs, max_length=None, stop_token_ids="auto"):
@mattdangerw (Member):

Why do we need to override generate here? Maybe we should do some refactoring to avoid this need.

@abuelnasr0 (Contributor, Author):

phi3 stops generation at the <|end|> and <|endoftext|> tokens by default. The generation will be bad if we don't stop at <|end|>.
Refactoring would be good. Maybe we can add a stop_token_ids variable to the CausalLMPreprocessor class to be used when stop_token_ids is "auto".
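
To make the behavior concrete, here is roughly how the "auto" default resolves; a sketch using the preset names from this PR, where token_to_id is the keras_nlp tokenizer lookup:

import keras_nlp

causal_lm = keras_nlp.models.Phi3CausalLM.from_preset("phi3_mini_4k_instruct_en")
tokenizer = causal_lm.preprocessor.tokenizer

# With stop_token_ids="auto" (the default), generation stops at both
# <|end|> and <|endoftext|>; passing the ids explicitly is equivalent.
stop_ids = [
    tokenizer.token_to_id("<|end|>"),
    tokenizer.token_to_id("<|endoftext|>"),
]
output = causal_lm.generate(
    "<|user|>\nHow to win?<|end|>\n<|assistant|>",
    max_length=128,
    stop_token_ids=stop_ids,
)
print(output)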

@mattdangerw (Member):

Got it! Thanks for the explainer. I think what you have now makes sense. Some sort of refactoring so a model can specify default stop tokens without touching the "business logic" of generate sgtm, but that does not need to be in this PR.

@mattdangerw (Member):

Also, it's kinda weird that there are two end token ids. Do you know why?

@abuelnasr0 (Contributor, Author):

<|end|> is the EOT (end-of-turn) token and is used when the model is used for chat, to indicate that <|user|> has ended their turn by writing the prompt or that <|assistant|> has ended its turn by generating text, so the turn passes to the other entity.
<|endoftext|> is just the regular EOS (end-of-sequence) token.

Here is an example of model input:

<|user|>\nQuestion<|end|>\n<|assistant|>

keras_nlp/src/models/phi3/phi3_presets.py (outdated, resolved)
tools/sentencepiece_testing/utils.py (resolved)
keras_nlp/src/models/phi3/phi3_rotary_embedding.py (outdated, resolved)
else:
    self.inverse_freq_long_factor = None

def _compute_cos_sin_embedding(self, inputs, start_index=0, positions=None):
@mattdangerw (Member):

This looks almost exactly the same as what's upstream. Is it possible to do this just by overriding self._get_inverse_freq(rotary_dim)?

If so, that would save a lot of code here.

@abuelnasr0 (Contributor, Author):

I will also need to override call(), because the cos and sine embeddings are also multiplied by a factor:
https://github.com/keras-team/keras-nlp/blob/0dff9f1eeda8dc37559c7b7a99514c1a8d469c17/keras_nlp/src/models/phi3/phi3_rotary_embedding.py#L129-L136

tools/sentencepiece_testing/create_phi3_test_proto.py (outdated, resolved)
@mattdangerw (Member)

Copying the presets over now on Kaggle. I will pull this in today.

@mattdangerw (Member)

Updated links, though I think things are still processing on the Kaggle side.

@mattdangerw mattdangerw added the kokoro:force-run Runs Tests on GPU label May 17, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Runs Tests on GPU label May 17, 2024
@mattdangerw mattdangerw merged commit a675aeb into keras-team:master May 17, 2024
10 checks passed