
Add support for Cohere's Command-R model #3433

Merged · 11 commits · Mar 27, 2024

Conversation

zeppombal
Contributor

https://huggingface.co/CohereForAI/c4ai-command-r-v01

The PR implements the C4AI Command-R model requested in #3330 and #3403.

@youkaichao
Member

Please run ./format.sh before submitting the PR and after any new commits to ensure compliance with linter checks. PRs failing to meet linter standards will not be merged.

@holyCowMp3

Any movement here? We're very excited about the possibility of using Command-R with vLLM :)

@zeppombal
Contributor Author

Will run formatting and commit later today.

@AlpinDale
Contributor

The implementation here doesn't seem to use the logit_scale from the model's config. We'll need to multiply the logits by the scale before sampling. If this isn't applied, the model becomes increasingly deterministic and temperature is largely ineffective (coherent output at temperatures as high as 10 and 15).
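
For illustration only, a minimal sketch of what applying the scale before sampling means; the helper name apply_logit_scale is just illustrative, and the surrounding sampler wiring is omitted (logit_scale is read from the model's config.json):

import torch

def apply_logit_scale(logits: torch.Tensor, logit_scale: float) -> torch.Tensor:
    # Command-R expects the raw LM-head logits to be multiplied by
    # config.logit_scale (a value below 1) before temperature/softmax;
    # without it the distribution is far too peaked.
    return logits * logit_scale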

KVCache = Tuple[torch.Tensor, torch.Tensor]


class CohereConfig(PretrainedConfig):

I think the CohereConfig class should be created in vllm/transformers_utils/configs/cohere.py, like the other models.


Alternatively, you can import it directly from transformers, as gemma does.
But it's only on the main branch and hasn't been released yet:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/cohere/configuration_cohere.py

Contributor Author

Thanks, will create it in vllm/transformers_utils/configs/cohere.py for now.

Collaborator

FYI, just use PretrainedConfig; there's no need to add a custom config file since vLLM will load it from config.json.

Member

@zeppombal FYI, we have now upgraded transformers to 4.39, so you can import CohereConfig directly from transformers.
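
As a hedged illustration (assuming transformers >= 4.39 is installed), the config can then be pulled straight from the library and the Hub:

from transformers import CohereConfig

# Loads config.json from the Hub; logit_scale is the field discussed above.
config = CohereConfig.from_pretrained("CohereForAI/c4ai-command-r-v01")
print(config.logit_scale, config.vocab_size)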

@zeppombal
Contributor Author

The implementation here doesn't seem to use the logit_scale from the model's config. We'll need to multiply the logits by the scale before sampling. If this isn't applied, the model becomes increasingly deterministic and temperature is largely ineffective (coherent output at temperatures as high as 10 and 15).

Good point @AlpinDale. Where do you think would be a good place to include this? Maybe in the sample method of CohereForCausalLM? I can't find another model in vLLM with this requirement.

@esmeetu
Collaborator

esmeetu commented Mar 19, 2024

The implementation here doesn't seem to use the logit_scale from the model's config. We'll need to multiply the logits by the scale before sampling. If this isn't applied, the model becomes increasingly deterministic and temperature is largely ineffective (coherent output at temperatures as high as 10 and 15).

@zeppombal @AlpinDale There's a PR (#3233) to make the logit scale configurable. After that, we only need to extend llama and make some minor changes to support this model.

@mwbyeon

mwbyeon commented Mar 19, 2024

@zeppombal
As @esmeetu mentioned, I think the logit_scale feature can be resolved when PR #3233 is merged.

Or you can simply inherit from the Sampler class to implement it, for example:

from typing import Optional

import torch
from torch import nn

# Import paths are assumed from the vLLM layout at the time of this PR;
# CohereConfig comes from the config discussion above.
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.sampler import Sampler


class LogitScaledSampler(Sampler):
    """Sampler that multiplies the logits by a fixed scale before sampling."""

    def __init__(self,
                 logit_scale: float,
                 vocab_size: int,
                 org_vocab_size: Optional[int] = None,
                 ) -> None:
        super().__init__(vocab_size, org_vocab_size)
        self._logit_scale = logit_scale

    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
        logits = super()._get_logits(hidden_states, embedding, embedding_bias)
        # Apply Command-R's logit_scale before the sampling step.
        logits *= self._logit_scale
        return logits


class CohereForCausalLM(nn.Module):
    def __init__(
        self,
        config: CohereConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        ...
        self.sampler = LogitScaledSampler(config.logit_scale,
                                          config.vocab_size)

@t3ga
Contributor

t3ga commented Mar 20, 2024

So, what's next?

@zeppombal
Contributor Author

@t3ga I think it makes sense to wait for PR #3233 to be merged.

@t3ga
Contributor

t3ga commented Mar 21, 2024

@t3ga I think it makes sense to wait for PR #3233 to be merged.

Ready

But also: #3433 (comment)

@zeppombal
Contributor Author

@t3ga I think it makes sense to wait for PR #3233 to be merged.

Ready

But also: #3433 (comment)

Great, will incorporate both things soon.

@esmeetu added the "new model" label on Mar 21, 2024
@zeppombal
Contributor Author

I still want to compare some generations with transformers, but won't be able to do so for the next couple of days.

@osilverstein

Tested with FP8 caching; it did not work.

@saurabhdash

I tried this with FP16 precision on TP4, and the generations seem random. I can take a look at it over the weekend if someone hasn't fixed it by then.

@WoosukKwon
Collaborator

@youkaichao Could you shepherd this PR so that it can get merged before the next release? This model is pretty important.

@youkaichao
Member

So we have to figure out why the output does not match the HF version before we can merge this PR.

@zeppombal
Contributor Author

Right now I'm unable to load the model; it is broken in the most recent transformers version (see this issue). I've submitted a PR to their model repo to help fix it.

@ywang96
Member

ywang96 commented Mar 25, 2024

@zeppombal I was able to run the model with transformers==4.39.1 and vLLM built from your branch.

Here's the code snippet:

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import gc
import torch

os.environ["TOKENIZERS_PARALLELISM"]="true"
model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "Hello, how are you?"
samplingparams = SamplingParams(max_tokens=100, temperature=0, skip_special_tokens=False)


# HF
hf_model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
gen_tokens = hf_model.generate(
    input_ids.cuda(), 
    max_new_tokens=100, 
    do_sample=False,
    )
hf_output = tokenizer.decode(gen_tokens[0][len(input_ids[0]):])
del hf_model
gc.collect()
torch.cuda.empty_cache()

# vLLM
vllm_model = LLM(model=model_id, tokenizer=model_id, tensor_parallel_size=2)
output = vllm_model.generate(prompts=[prompt], sampling_params=samplingparams)
vllm_output = output[0].outputs[0].text

print(vllm_output)
print(tokenizer(vllm_output).input_ids == tokenizer(hf_output).input_ids)

and here are the outputs:

 I hope you are well and that you have had a great week. I have been busy working on my new book, which is coming along nicely. I have also been working on some new designs for my Etsy shop. I have been making some new Christmas cards and I have also been working on some new Christmas ornaments. I have been having a lot of fun with the ornaments. I have been making them from vintage images and they are so cute. I will be listing them in my shop soon.
True

I've verified the completion works with TP=2 and TP=4, but there seemed to be something wrong when I tried to use a prompt formatted by its chat template - will investigate more and report back if anyone doesn't get there before me.

Never mind, prompts with the chat template work. I'll test it with the example.txt prompts we use in our testing suite.

@youkaichao
Member

@ywang96 please ping me when you finish the testing.

@zeppombal please resolve the conflict with the main branch.

@ywang96
Member

ywang96 commented Mar 25, 2024

@youkaichao Hmm... I ran into an issue where the outputs differ by 1 or 2 tokens when I run them on the example.txt prompts - it would be nice if you or @zeppombal could try to repro it. This was done on an A100-80GB.

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import gc
import torch

os.environ["TOKENIZERS_PARALLELISM"]="true"
model_id = "CohereForAI/c4ai-command-r-v01"
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompts = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
    "Describe the basic components of a neural network and how it can be trained.",
    "Write a short story about a robot that dreams for the first time.",
    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'"
]

# HF
hf_model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto')
chat_prompts = []
hf_token_ids = []
hf_chat_token_ids = []
for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    gen_tokens = hf_model.generate(
        input_ids.cuda(), 
        max_new_tokens=128, 
        do_sample=False,
        )
    hf_token_ids.append(gen_tokens[0][len(input_ids[0]):])

    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    chat_prompts.append(chat_prompt)
    chat_input_ids = tokenizer(chat_prompt, return_tensors="pt").input_ids
    gen_tokens = hf_model.generate(
        chat_input_ids.cuda(), 
        max_new_tokens=128, 
        do_sample=False,
        )
    hf_chat_token_ids.append(gen_tokens[0][len(chat_input_ids[0]):])

del hf_model
gc.collect()
torch.cuda.empty_cache()

# vLLM
vllm_model = LLM(model=model_id, tokenizer=model_id, tensor_parallel_size=4)
samplingparams = SamplingParams(max_tokens=128, temperature=0.0, skip_special_tokens=False)
vllm_outputs = vllm_model.generate(prompts=prompts, sampling_params=samplingparams)
vllm_chat_outputs = vllm_model.generate(prompts=chat_prompts, sampling_params=samplingparams)

for i in range(len(prompts)):
    vllm_token_ids = vllm_outputs[i].outputs[0].token_ids
    vllm_chat_token_ids = vllm_chat_outputs[i].outputs[0].token_ids

    assert vllm_token_ids == hf_token_ids[i].tolist(), f"{i} - vLLM: {vllm_token_ids}\nHF: {hf_token_ids[i].tolist()}"
    assert vllm_chat_token_ids == hf_chat_token_ids[i].tolist(), f"{i} - vLLM: {vllm_chat_token_ids}\nHF: {hf_chat_token_ids[i].tolist()}"

@zeppombal
Contributor Author

@ywang96 thanks, I'll try to reproduce.

@esmeetu
Collaborator

esmeetu commented Mar 25, 2024

@ywang96 Possibly you missed wrapping the prompts with the chat template, e.g. <|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>.
vLLM adds the BOS token by default, so we don't need to add the BOS token in the template.
This PR's result matches the HF output in my local environment.
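
For reference, a minimal sketch of producing that wrapped prompt with the tokenizer's chat template (same model id as in the snippets above); note that the templated string may already contain <BOS_TOKEN>, and vLLM prepends BOS itself, so watch for duplication:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
messages = [{"role": "user", "content": "Hello, how are you?"}]
# Returns the templated string with the <|START_OF_TURN_TOKEN|>... markers.
chat_prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(chat_prompt)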

@t3ga
Contributor

t3ga commented Mar 27, 2024

I have two questions:

  1. Has anybody tested the original quantized model, CohereForAI/c4ai-command-r-v01-4bit? If it works, could you please share instructions on how you did it?
  2. Can we change the context window in this implementation to 128k? If yes, please explain how.

Thanks for your answers and for your work, guys.

@saurabhdash

saurabhdash commented Mar 27, 2024

I have two questions:

  1. Has anybody tested the original quantized model, CohereForAI/c4ai-command-r-v01-4bit? If it works, could you please share instructions on how you did it?
  2. Can we change the context window in this implementation to 128k? If yes, please explain how.

Thanks for your answers and for your work, guys.

The 4-bit model is meant for use with Hugging Face, as it uses bitsandbytes for the NF4 format.
I think the 128k context can be enabled with the argument --max-model-len=128000.
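
A minimal sketch of the offline-engine equivalent of that flag (argument names follow vLLM's standard engine args; whether your GPUs can hold the KV cache for a 128k context is a separate question):

from vllm import LLM

# Cap the context window at 128k tokens instead of the config default.
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    tensor_parallel_size=4,
    max_model_len=128000,
)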

@t3ga
Contributor

t3ga commented Mar 27, 2024

I have two questions:

  1. Has anybody tested the original quantized model, CohereForAI/c4ai-command-r-v01-4bit? If it works, could you please share instructions on how you did it?
  2. Can we change the context window in this implementation to 128k? If yes, please explain how.

Thanks for your answers and for your work, guys.

The 4-bit model is meant for use with Hugging Face, as it uses bitsandbytes for the NF4 format. I think the 128k context can be enabled with the argument --max-model-len=128000.

Maybe we can somehow convert their bnb quantization to GPTQ, for example? That quantization's accuracy is near perfect; GGUF quantizations of this model, for instance, don't reach the accuracy of the official bnb quantization.

Or maybe you have some examples of how I can quantize the original full-size model to GPTQ myself without any accuracy drop on metrics like multilinguality and RAG compatibility.

@zeppombal
Contributor Author

@saurabhdash maybe decreasing gpu_memory_utilization? I think OOM issues can arise in the harness if the instances over which you're performing inference are large enough. Decreasing the memory utilization of the vllm model leaves some space for that.
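
A minimal sketch of that suggestion (the 0.5 value is just an example to leave headroom for the harness; tune it to your setup):

from vllm import LLM

# Let vLLM claim only ~50% of each GPU so the eval harness has room
# for its own allocations (e.g. HF-side logits).
llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.5,
)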

@t3ga
Contributor

t3ga commented Mar 27, 2024

Found a quantized model: https://huggingface.co/Cyleux/command-r-gptq

Doesn't work with vLLM:

Transformers 4.39.1

vLLM: zeppombal@9f8a3c7

Used code:

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="Cyleux/command-r-gptq")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Obtained error:

KeyError                                  Traceback (most recent call last)
Cell In[2], line 14
     11 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     13 # Create an LLM.
---> 14 llm = LLM(model="Cyleux/command-r-gptq")
     15 # Generate texts from the prompts. The output is a list of RequestOutput objects
     16 # that contain the prompt, generated text, and other information.
     17 outputs = llm.generate(prompts, sampling_params)

File ~/t3ga/vllm/vllm/entrypoints/llm.py:111, in LLM.__init__(self, model, tokenizer, tokenizer_mode, trust_remote_code, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, enforce_eager, max_context_len_to_capture, disable_custom_all_reduce, **kwargs)
     92     kwargs["disable_log_stats"] = True
     93 engine_args = EngineArgs(
     94     model=model,
     95     tokenizer=tokenizer,
   (...)
    109     **kwargs,
    110 )
--> 111 self.llm_engine = LLMEngine.from_engine_args(engine_args)
    112 self.request_counter = Counter()

File ~/t3ga/vllm/vllm/engine/llm_engine.py:150, in LLMEngine.from_engine_args(cls, engine_args)
    147     executor_class = GPUExecutor
    149 # Create the LLM engine.
--> 150 engine = cls(*engine_configs,
    151              executor_class=executor_class,
    152              log_stats=not engine_args.disable_log_stats)
    153 return engine

File ~/t3ga/vllm/vllm/engine/llm_engine.py:106, in LLMEngine.__init__(self, model_config, cache_config, parallel_config, scheduler_config, device_config, lora_config, vision_language_config, executor_class, log_stats)
    103 self.detokenizer = Detokenizer(self.tokenizer)
    104 self.seq_counter = Counter()
--> 106 self.model_executor = executor_class(model_config, cache_config,
    107                                      parallel_config, scheduler_config,
    108                                      device_config, lora_config,
    109                                      vision_language_config)
    111 # Ping the tokenizer to ensure liveness if it runs in a
    112 # different process.
    113 self.tokenizer.ping()

File ~/t3ga/vllm/vllm/executor/gpu_executor.py:37, in GPUExecutor.__init__(self, model_config, cache_config, parallel_config, scheduler_config, device_config, lora_config, vision_language_config)
     34 self.vision_language_config = vision_language_config
     36 # Instantiate the worker and load the model to GPU.
---> 37 self._init_worker()
     39 # Profile the memory usage and initialize the cache.
     40 self._init_cache()

File ~/t3ga/vllm/vllm/executor/gpu_executor.py:66, in GPUExecutor._init_worker(self)
     52 self.driver_worker = Worker(
     53     self.model_config,
     54     self.parallel_config,
   (...)
     63     is_driver_worker=True,
     64 )
     65 self.driver_worker.init_device()
---> 66 self.driver_worker.load_model()

File ~/t3ga/vllm/vllm/worker/worker.py:106, in Worker.load_model(self)
    105 def load_model(self):
--> 106     self.model_runner.load_model()

File ~/t3ga/vllm/vllm/worker/model_runner.py:95, in ModelRunner.load_model(self)
     93 def load_model(self) -> None:
     94     with CudaMemoryProfiler() as m:
---> 95         self.model = get_model(
     96             self.model_config,
     97             self.device_config,
     98             lora_config=self.lora_config,
     99             vision_language_config=self.vision_language_config,
    100             parallel_config=self.parallel_config,
    101             scheduler_config=self.scheduler_config)
    103     self.model_memory_usage = m.consumed_memory
    104     logger.info(f"Loading model weights took "
    105                 f"{self.model_memory_usage / float(2**30):.4f} GB")

File ~/t3ga/vllm/vllm/model_executor/model_loader.py:96, in get_model(model_config, device_config, **kwargs)
     93         initialize_dummy_weights(model)
     94     else:
     95         # Load the weights from the cached or downloaded files.
---> 96         model.load_weights(model_config.model, model_config.download_dir,
     97                            model_config.load_format, model_config.revision)
     98 return model.eval()

File ~/t3ga/vllm/vllm/model_executor/models/commandr.py:333, in CohereForCausalLM.load_weights(self, model_name_or_path, cache_dir, load_format, revision)
    331     break
    332 else:
--> 333     param = params_dict[name]
    334     weight_loader = getattr(param, "weight_loader",
    335                             default_weight_loader)
    336     weight_loader(param, loaded_weight)

KeyError: 'model.layers.25.mlp.down_proj.bias'

But it works OK with transformers:

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

model_id = "Cyleux/command-r-gptq"
tokenizer = AutoTokenizer.from_pretrained(model_id)

config = AutoConfig.from_pretrained(model_id)
#config.quantization_config["use_exllama"] = True
config.quantization_config["disable_exllama"] = False
config.quantization_config["exllama_config"] = {"version":2}

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", config=config)

# Format message with the command-r chat template
messages = [{"role": "user", "content": "Hi. How are you?"}]
input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
## <BOS_TOKEN><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello, how are you?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

gen_tokens = model.generate(
    input_ids, 
    max_new_tokens=100, 
    do_sample=True, 
    temperature=0.3,
    )

gen_text = tokenizer.decode(gen_tokens[0])
print(gen_text)

@saurabhdash

@saurabhdash maybe decreasing gpu_memory_utilization? I think OOM issues can arise in the harness if the instances over which you're performing inference are large enough. Decreasing the memory utilization of the vllm model leaves some space for that.

I am using a very modest batch size of 4 and it still OOMs. @youkaichao any idea what might be going wrong? The model weights should be 35GB on each GPU and the logits for Likelihoods should be ~8GB.

@t3ga
Contributor

t3ga commented Mar 27, 2024

@saurabhdash maybe decreasing gpu_memory_utilization? I think OOM issues can arise in the harness if the instances over which you're performing inference are large enough. Decreasing the memory utilization of the vllm model leaves some space for that.

I am using a very modest batch size of 4 and it still OOMs. @youkaichao any idea what might be going wrong? The model weights should be 35GB on each GPU and the logits for Likelihoods should be ~8GB.

Sometimes the cache uses another 1-3 GB of VRAM.

@zeppombal
Contributor Author

@t3ga thanks for testing the quantized model. I believe adding support for it could come in another PR.

@saurabhdash do you have the same problem with a different model of similar size? Say, CodeLlama-34B?

@saurabhdash

@t3ga thanks for testing the quantized model. I believe adding support for it could come in another PR.

@saurabhdash do you have the same problem with a different model of similar size? Say, CodeLlama-34B?

Okay, so your advice seems to be working. I set the GPU utilization to 0.5 and it seems to run with bs=4. Only the first GPU has near-full utilization; the rest seem to be around 50%.
Will update once I have a couple of benchmarks done.

@youkaichao
Member

@zeppombal can you summarize the status quo of this PR? The conversation has been quite long now.

To my understanding, support for quantized version can be made in a separate PR. Let's focus on one thing at a time.

@saurabhdash

@zeppombal can you summarize the status quo of this PR? The conversation has been quite long now.

To my understanding, support for quantized version can be made in a separate PR. Let's focus on one thing at a time.

I am running evals as a final sanity check for the model. I agree with you, the quantized version should be its own PR. I have already run GSM8K and it checks out.

@zeppombal
Contributor Author

@youkaichao if the numbers from @saurabhdash come out OK, I'd say the implementation can be trusted and the PR should be ready for merging.

@saurabhdash

@zeppombal @youkaichao I ran HellaSwag 10-shot and GSM8K 5-shot on TP4. Both numbers check out. This model looks functionally correct and ready to merge.

@saurabhdash

@youkaichao I had a question. There seems to be a warning saying the size of the tokenizer is not 256k. In case there is tokenizer padding for efficiency, how does vLLM handle that?

@youkaichao
Member

@zeppombal could you please merge the main branch into this branch to trigger CI? I tested in a fresh environment and it fails with ValueError: 'cohere' is already used by a Transformers config.

This might be an issue with transformers, because we recently updated the transformers version.
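
A hedged illustration of the likely clash (assuming this branch still registers its own CohereConfig while transformers 4.39 already ships a config for model type "cohere"):

from transformers import AutoConfig, PretrainedConfig

class CohereConfig(PretrainedConfig):
    model_type = "cohere"

# transformers >= 4.39 already maps "cohere" to its built-in config class,
# so registering a second class for the same model type raises:
#   ValueError: 'cohere' is already used by a Transformers config, pick another name.
AutoConfig.register("cohere", CohereConfig)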

@youkaichao
Member

Okay, I did a quick test and the model output is strictly the same as the Hugging Face transformers version, using tests/distributed/test_basic_distributed_correctness.py.

I also ran a subjective test, and the quality of the output is good:

from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="CohereForAI/c4ai-command-r-v01", tensor_parallel_size=2)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Output:

Prompt: 'Hello, my name is', Generated text: ' Sarah Lee and I am the MS teacher at Berlin Elementary School. I have taught'
Prompt: 'The president of the United States is', Generated text: ' no longer a human being.\nHe is a symbol. A great big,'
Prompt: 'The capital of France is', Generated text: ' a city that deserves your attention. Paris, the City of Love and Lights,'
Prompt: 'The future of AI is', Generated text: ' not an all-knowing, all-controlling super intelligence, but a'

Thanks for your contribution!

@youkaichao youkaichao merged commit 1182607 into vllm-project:main Mar 27, 2024
33 checks passed
@youkaichao
Member

There seems to be a warning saying the size of the tokenizer is not 256k. In case there is tokenizer padding for efficiency, how does vLLM handle that?

@saurabhdash let me investigate and respond to you later.

@osilverstein

Does GPTQ work now, @youkaichao?

xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 31, 2024
Co-authored-by: José Maria Pombal <[email protected]>
Co-authored-by: youkaichao <[email protected]>
@egortolmachev
Contributor

egortolmachev commented Apr 4, 2024

@zeppombal can you summarize the status quo of this PR? The conversation has been quite long now.

To my understanding, support for quantized version can be made in a separate PR. Let's focus on one thing at a time.

Ready: #3849

Labels
new model (Requests to new models), release-blocker (This PR/issue blocks the next release, therefore deserves highest priority)