
Support starcoder2 architecture #3089

Merged · 12 commits merged into vllm-project:main on Feb 29, 2024

Conversation

sh0416
Contributor

@sh0416 sh0416 commented Feb 28, 2024

#3075

I did my best to support the starcoder2 architecture.

Since Hugging Face Transformers currently supports starcoder2 only in its main development branch, this requires installing transformers from source.

My test code is as follows.

Python 3.10.13 (main, Sep 11 2023, 13:44:35) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from vllm import LLM, SamplingParams
>>> llm = LLM(model="bigcode/starcoder2-3b")
INFO 02-28 17:53:55 llm_engine.py:79] Initializing an LLM engine with config: model='bigcode/starcoder2-3b', tokenizer='bigcode/starcoder2-3b', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=16384, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, seed=0)
INFO 02-28 17:53:58 weight_utils.py:163] Using model weights format ['*.safetensors']
INFO 02-28 17:54:02 llm_engine.py:341] # GPU blocks: 139749, # CPU blocks: 8738
INFO 02-28 17:54:05 model_runner.py:676] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 02-28 17:54:05 model_runner.py:680] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode. You can also reduce the `max_num_seqs` as needed to decrease memory usage.
INFO 02-28 17:54:06 model_runner.py:748] Graph capturing finished in 2 secs.
>>> llm.generate(['def print_hello_world():'], SamplingParams(temperature=0.01, top_p=0.95, max_tokens=32))
Processed prompts: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  3.38it/s]
[RequestOutput(request_id=0, prompt='def print_hello_world():', prompt_token_ids=[610, 1489, 100, 7670, 100, 5879, 2284], prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='\n    print("Hello World")\n\ndef print_hello_world_with_name(name):\n    print("Hello World, " + name)\n', token_ids=[303, 1489, 459, 8302, 10914, 678, 222, 222, 610, 1489, 100, 7670, 100, 5879, 100, 1814, 100, 444, 45, 444, 731, 303, 1489, 459, 8302, 10914, 49, 332, 494, 655, 46, 222], cumulative_logprob=0.0, logprobs=None, finish_reason=length)], finished=True, metrics=RequestMetrics(arrival_time=12221976.944665303, last_token_time=12221976.944665303, first_scheduled_time=1709142852.4397888, first_token_time=1709142852.5246108, time_in_queue=1696920875.4951236, finished_time=1709142852.7353413), lora_request=None)]

The generation result looks normal, but I would like a review of my code to check whether there are any bugs.

Because the starcoder2 codebase derives from Mistral and GPT BigCode, I referred to both of those model files in this project.

@sh0416
Contributor Author

sh0416 commented Feb 28, 2024

@esmeetu Could you check it for me?

@WoosukKwon WoosukKwon requested a review from esmeetu February 28, 2024 18:08
@esmeetu
Collaborator

esmeetu commented Feb 28, 2024

@sh0416 Thanks for your quick PR! I left some comments. Could you also add this new model to the README and the documentation? And please add it to the CI test in test_models.py to verify there is no difference from Hugging Face Transformers' output.

@esmeetu
Collaborator

esmeetu commented Feb 28, 2024

Another request: could you help support the other starcoder2 series models (7B and 15B)? I think there are only small differences between them.

@WoosukKwon WoosukKwon mentioned this pull request Feb 28, 2024
@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

I've addressed all the comments, so please review my revision.

There are two key differences in this branch.

  • Starcoder2 is only supported in the Transformers main development branch.
    • I will remove this workaround once the Starcoder2 implementation lands in a stable release, such as 4.39.0 or newer.
    • Currently, I use try-except logic to import Starcoder2Config and fall back to PretrainedConfig if it doesn't exist (see the sketch after this list).
  • Starcoder2-3b doesn't have an architectures field in its config.
    • Therefore, a custom Starcoder2Config is necessary to handle this issue.
    • If architectures is added to the config file, I will remove the custom config code.
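
Roughly, the fallback looks like the following minimal sketch (an illustration, not the PR's exact code; it assumes the transformers development branch exposes Starcoder2Config at the top level, while stable releases only provide PretrainedConfig):

from transformers import PretrainedConfig

try:
    # Available only on the transformers main development branch.
    from transformers import Starcoder2Config
except ImportError:
    # Stable releases: fall back to the generic config class.
    Starcoder2Config = PretrainedConfig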

Finally, I tested Starcoder2-7b and Starcoder2-15b, and they produce correct results under the given CI test, but I did not include them in the test cases due to the computational cost of such large models.

Do I have to report this result?

Relevant link: https://huggingface.co/bigcode/starcoder2-3b/discussions/2

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

Oh, Starcoder2-15b does not pass the test. Please wait a minute.

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

Starcoder2-15b has separate lm_head weights (instead of tying them to the embedding table), and now the 15b model passes the test.
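
To illustrate the difference (a hypothetical sketch with assumed names, not the PR's actual implementation): when tie_word_embeddings is set, the output projection can reuse the token embedding matrix, whereas Starcoder2-15b ships a separate lm_head weight.

import torch
import torch.nn as nn

class OutputHead(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int,
                 embed_tokens: nn.Embedding, tie_word_embeddings: bool):
        super().__init__()
        if tie_word_embeddings:
            # Tied case: reuse the input embedding weights.
            self.weight = embed_tokens.weight
        else:
            # Untied case (Starcoder2-15b): separate lm_head parameters,
            # to be loaded from the checkpoint.
            self.weight = nn.Parameter(torch.empty(vocab_size, hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states to vocabulary logits.
        return nn.functional.linear(hidden_states, self.weight)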

@sh0416 sh0416 requested a review from esmeetu February 29, 2024 05:34
@esmeetu
Collaborator

esmeetu commented Feb 29, 2024

@sh0416 Good job! Having read Starcoder2Config, there is no attribute_map there, so I think we can safely replace it with PretrainedConfig. Could you try that and confirm it works?
LGTM! Thanks for your contribution!

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

@esmeetu Thanks for the fast response.

I don't know why some CI actions that are unrelated to this change failed.

Could you finalize this PR?

@esmeetu esmeetu enabled auto-merge (squash) February 29, 2024 05:55
auto-merge was automatically disabled February 29, 2024 06:09

Head branch was pushed to by a user without write access

@esmeetu
Collaborator

esmeetu commented Feb 29, 2024

@sh0416 Yeah, how did you pass the test locally? It seems to break when loading the config file, because that repo doesn't have a custom configuration file.

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

My test procedure is as follows.

  1. Install vllm from source.
  2. Install transformers from source.
  3. Download the starcoder2 weights via AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-3b"), along with the other model weights (see the sketch after this list).
  4. Run pytest tests/models/test_models.py::test_models[128-float-bigcode/starcoder2-3b]
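
Step 3 in Python form (a minimal sketch; it relies on the source install of transformers from step 2, since stable releases at the time do not recognize the starcoder2 model type):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Pre-download the checkpoint and tokenizer so the test run itself
# does not depend on the network.
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-3b")
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-3b")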

I will repeat this procedure on my local machine to reproduce the error you encountered, so please wait.

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

FYI, it seems that the Hugging Face Hub is under maintenance and is currently unstable.

https://x.com/huggingface/status/1762954032312639702?s=20

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

> @sh0416 Yeah, how did you pass the test locally? It seems to break when loading the config file, because that repo doesn't have a custom configuration file.

It seems that installing transformers from source is mandatory. AutoConfig in the stable version cannot handle Starcoder2Config, which exists only in the main development branch.
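
For illustration (assumed behavior of a stable transformers release that predates Starcoder2 support; the exact exception type may differ):

from transformers import AutoConfig

try:
    # Resolves to Starcoder2Config on the development branch.
    config = AutoConfig.from_pretrained("bigcode/starcoder2-3b")
    print(type(config).__name__)
except (KeyError, ValueError) as exc:
    # On stable releases the "starcoder2" model_type is unknown, so this fails.
    print(f"Config load failed: {exc}")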

@WoosukKwon
Collaborator

@sh0416 @esmeetu I added a hack to temporarily fix the issue. Does it look ok?

@esmeetu
Collaborator

esmeetu commented Feb 29, 2024

> @sh0416 @esmeetu I added a hack to temporarily fix the issue. Does it look ok?

Not bad😎.

@sh0416
Contributor Author

sh0416 commented Feb 29, 2024

> @sh0416 @esmeetu I added a hack to temporarily fix the issue. Does it look ok?

I've tested it with your commit. The CI test requires reference generation results from HF transformers, so an error is still raised when running the test cases. However, using vLLM itself is OK:

from vllm import LLM
llm = LLM("bigcode/starcoder2-3b") # ok

In my opinion, Starcoder2 will be included in 4.39.0 since it has already been merged into the Transformers main branch, so this test issue should be resolved in the near future.
If the rule for merging a PR is that the test must pass against the stable HF transformers release, then this should not be merged yet.
However, I think it is worth merging now, since vLLM users can already use the model without errors.

Thank you

@WoosukKwon
Collaborator

@sh0416 Thanks for your input! Generally, we aim to support new models at the earliest possible time and make them more robust later. Given the importance of the StarCoder2 model, I think the current hack is acceptable.

@WoosukKwon WoosukKwon merged commit bfdcfa6 into vllm-project:main Feb 29, 2024
20 of 22 checks passed
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 4, 2024