Add support for Cohere's Command-R model #3433
Conversation
Please run |
Any movement here? We are so excited about the possibility of using Command-R with vLLM :)
Will run formatting and commit later today. |
The implementation here doesn't seem to use the |
The review comments below refer to these lines:

```python
KVCache = Tuple[torch.Tensor, torch.Tensor]

class CohereConfig(PretrainedConfig):
```
I think the `CohereConfig` class should be created in `vllm/transformers_utils/configs/cohere.py` like the other models.
Alternatively, you can import it directly from transformers, like gemma does. But it's only on the main branch and hasn't been released yet: https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/configuration_cohere.py
Thanks, will create it in `vllm/transformers_utils/configs/cohere.py` for now.
FYI, just use `PretrainedConfig`; no need to add a custom config file since vLLM will load it from `config.json`.
@zeppombal FYI, we have now upgraded transformers to 4.39, where you can import `CohereConfig` directly from `transformers`.
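For readers following along, a minimal sketch of what that looks like, assuming transformers >= 4.39 and the checkpoint discussed in this PR:

```python
from transformers import AutoConfig, CohereConfig  # CohereConfig ships with transformers >= 4.39

# vLLM can also just rely on config.json via AutoConfig, which resolves to CohereConfig.
config = AutoConfig.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(type(config).__name__, config.logit_scale)  # e.g. "CohereConfig" and the model's logit scale
```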
Good point @AlpinDale. Where do you think is a good place to include this? Maybe in the |
@zeppombal @AlpinDale There's a PR to make |
@zeppombal or you can simply inherit from the `Sampler` class:

```python
class LogitScaledSampler(Sampler):

    def __init__(
        self,
        logit_scale: float,
        vocab_size: int,
        org_vocab_size: Optional[int] = None,
    ) -> None:
        super().__init__(vocab_size, org_vocab_size)
        self._logit_scale = logit_scale

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        logits = super()._get_logits(hidden_states, embedding, embedding_bias)
        logits *= self._logit_scale
        return logits


class CohereForCausalLM(nn.Module):

    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        ...
        self.sampler = LogitScaledSampler(self.config.logit_scale, config.vocab_size)
```
So, what's next? |
Ready. But also: #3433 (comment)
Great, will incorporate both things soon.
I still want to compare some generations with transformers, but won't be able to do so for the next couple of days.
Tested with fp8 caching; it did not work.
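For context, a minimal sketch of what an fp8 KV-cache test typically looks like in vLLM; the `kv_cache_dtype` value is an assumption and depends on the vLLM version:

```python
from vllm import LLM, SamplingParams

# Assumption: "fp8_e5m2" is the fp8 KV-cache dtype accepted by this vLLM build;
# this only changes the KV-cache storage dtype, not the model weights.
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    tensor_parallel_size=4,
    kv_cache_dtype="fp8_e5m2",
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```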
I tried this with FP16 precision on TP4, the generations seem random. I can try to take a look at it over the weekend, if someone hasn't fixed it by then. |
@youkaichao Could you shepherd this PR so that it can get merged before the next release? This model is pretty important. |
So we have to figure out why the output does not match the HF version before we can merge this PR.
Right now I'm unable to load the model; it is broken in the most recent |
@zeppombal I was able to run the model with … Here's the code snippet
and here are the outputs:
I've verified the completion works with TP=2 and TP=4, but … Never mind, prompts with the chat template work, and I'll test it with …
@ywang96 please ping me when you finish the testing. @zeppombal please resolve the conflict with the main branch. |
@youkaichao Hmmm.. I ran into an issue where the outputs differ by 1 or 2 tokens when I tried to run it on the …
@ywang96 thanks, I'll try to reproduce. |
@ywang96 Possibly you missed wrapping the prompts with the chat template, like …
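For reference, a hedged sketch of wrapping a prompt in the model's chat template before sending it to vLLM, assuming the Hugging Face tokenizer for this checkpoint ships a chat template (which CohereForAI/c4ai-command-r-v01 does):

```python
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")

# Render the chat template to a plain string and let vLLM handle generation.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you?"}],
    tokenize=False,
    add_generation_prompt=True,
)

llm = LLM(model="CohereForAI/c4ai-command-r-v01", tensor_parallel_size=2)
print(llm.generate([prompt], SamplingParams(temperature=0.8, top_p=0.95))[0].outputs[0].text)
```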
I have two questions:
|
The 4-bit model is for use with huggingface as it uses bitsandbytes for the NF4 format. |
Maybe we can somehow convert their bnb quantization to GPTQ, for example? That quantization's accuracy is near perfect; for example, GGUF-quantized versions of this model don't reach the accuracy of the official bnb quantization. Or maybe you have some examples of how I can quantize the original full-size model to GPTQ myself without any accuracy drop on metrics like multilinguality and RAG compatibility.
@saurabhdash maybe decreasing |
Found a quantized model: https://huggingface.co/Cyleux/command-r-gptq
It doesn't work with vLLM.
Transformers: 4.39.1
vLLM: zeppombal@9f8a3c7
Used code:
Obtained error:
But it works fine with transformers:
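For readers, an illustrative sketch of how a GPTQ checkpoint is normally pointed at in vLLM; this is not the exact code used above, and this particular checkpoint reportedly fails to load:

```python
from vllm import LLM

# vLLM usually auto-detects GPTQ from the checkpoint's quantization config,
# but it can also be requested explicitly.
llm = LLM(
    model="Cyleux/command-r-gptq",
    quantization="gptq",
    tensor_parallel_size=2,
)
```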
I am using a very modest batch size of 4 and it still OOMs. @youkaichao any idea what might be going wrong? The model weights should be 35 GB on each GPU and the logits for likelihoods should be ~8 GB.
Sometimes the cache uses another 1-3 GB of VRAM.
@t3ga thanks for testing the quantized model. I believe adding support for it could come in another PR. @saurabhdash do you have the same problem with a different model of similar size? Say, CodeLlama-34B? |
Okay, so your advice seems to be working. I set the GPU utilization to 0.5 and it seems to run with bs=4. It's only the first GPU that has near-full utilization; the rest of them seem to be around 50%.
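A minimal sketch of the workaround described above, assuming the setting being changed is vLLM's `gpu_memory_utilization` argument:

```python
from vllm import LLM

# Lowering the memory-utilization target leaves headroom for the large logits
# buffers on top of the model weights and KV cache.
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.5,  # default is 0.9
)
```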
@zeppombal can you summarize the status quo of this PR? The conversation has been quite long now. To my understanding, support for the quantized version can be added in a separate PR. Let's focus on one thing at a time.
I am running evals as a final sanity check for the model. I agree with you, the quantized version should be its own PR. I have already run gsm8k and it checks out.
@youkaichao if the numbers from @saurabhdash come out OK, I'd say the implementation can be trusted and the PR should be ready for merging. |
@zeppombal @youkaichao I ran Hellaswag 10shot and GSM 5shot on TP4. Both numbers check out. This model looks functionally correct and ready to merge. |
@youkaichao I had a question. There seems to be a warning saying the size of the tokenizer is not 256k. In case there is tokenizer padding for efficiency, how does vLLM handle that?
@zeppombal could you please merge the main branch into this branch to trigger a CI? I tested in a fresh environment and it says … Might be some issue with transformers, because we recently updated the transformers version.
Okay, I did a quick test and the model output is strictly the same as the huggingface transformers version, using …

I also made a subjective test, and the quality of the output is good:

```python
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="CohereForAI/c4ai-command-r-v01", tensor_parallel_size=2)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

Output:
Thanks for your contribution! |
@saurabhdash let me investigate and respond to you later. |
Does GPTQ work now, @youkaichao?
Co-authored-by: José Maria Pombal <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Ready: #3849 |
https://huggingface.co/CohereForAI/c4ai-command-r-v01
The PR implements the C4AI Command-R model requested in #3330 and #3403.