
Why does FusedBitLinear.forward() use F.linear() with float16 inputs? #19

Open
AACengineer opened this issue Jun 13, 2024 · 3 comments

@AACengineer

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm  # importing mmfreelm makes its model classes available to transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
name = '/mnt/workspace/MMfreeLM-370M'
tokenizer = AutoTokenizer.from_pretrained(name)
# load the 370M checkpoint on the GPU in float16
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
# sample a short continuation from the prompt
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

"The FusedBitLinear.forward() function calls the LayerNormLinearQuantFn.forward() function. Why are both x and w in the F.linear() function float16? Shouldn't x be int8 and w be within the set {-1, 0, 1}?"

@ridgerchu
Owner

Hi, this is due to speed considerations. We found that bf16 gives the fastest speed for these operations, so we keep it. If you take a look at the inner values, you will find that the activation is INT8 and the weight is ternary. This is so-called fake quantization: the tensors use a high-precision data type, but their values have already been constrained to the low-precision grid.
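
For readers following along, here is a minimal sketch of what fake quantization means here, written in plain PyTorch (the helper names activation_quant and weight_quant and the exact scaling rules are illustrative assumptions, not necessarily the repo's kernels). The tensors keep a float16 dtype, but their values are snapped to an INT8 grid (activations) and to a scaled {-1, 0, 1} set (weights) before being passed to F.linear():

import torch
import torch.nn.functional as F

def activation_quant(x):
    # per-token absmax quantization: snap values onto the INT8 grid [-128, 127], keep fp16 dtype
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale

def weight_quant(w):
    # mean-absmax scaling, then rounding to the ternary set {-1, 0, 1}, keep fp16 dtype
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    return (w * scale).round().clamp(-1, 1) / scale

x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
w = torch.randn(8, 16, dtype=torch.float16, device="cuda")
y = F.linear(activation_quant(x), weight_quant(w))  # dtypes are fp16, values are low-precision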

@AACengineer
Author

As you mentioned, the activation is INT8 and the weight is ternary, yet both inputs to F.linear() are still quantized float16 tensors.
However, F.linear() still involves multiplication operations, which is not entirely consistent with the matmul-free concept. Is it possible to implement the functionality of F.linear() using only add/sub and similar operators in a GPU environment?

@ridgerchu
Owner

Yes, for training, using matmul is the most efficient approach, and the matmul-free case can be seen as a special case of matmul, so we still use F.linear here. To the best of my knowledge, it is rather hard to leverage matmul-free operations in a GPU environment.
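
For intuition only, here is a small sketch (assuming ternary weights in {-1, 0, 1}, plain PyTorch, and a hypothetical helper ternary_linear_addsub, not the repo's Triton kernels) showing that a linear layer with ternary weights is mathematically just signed additions, with no multiplications needed. In practice this masked gather-and-sum is far slower on a GPU than a fused bf16 matmul, which is why F.linear is still used:

import torch

def ternary_linear_addsub(x, w_ternary):
    # x: (batch, in_features); w_ternary: (out_features, in_features) with entries in {-1, 0, 1}
    out = torch.zeros(x.shape[0], w_ternary.shape[0], dtype=x.dtype)
    for j in range(w_ternary.shape[0]):
        pos = w_ternary[j] == 1    # inputs that get added
        neg = w_ternary[j] == -1   # inputs that get subtracted
        out[:, j] = x[:, pos].sum(dim=-1) - x[:, neg].sum(dim=-1)
    return out

x = torch.randn(4, 16)
w = torch.randint(-1, 2, (8, 16)).float()
assert torch.allclose(ternary_linear_addsub(x, w), x @ w.t(), atol=1e-5)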
