[Bug]: MiniCPM-Llama3-V-2_5 error when tensor_parallel_size>1 #6946

Closed
LSC527 opened this issue Jul 30, 2024 · 1 comment · Fixed by #6836
Labels
bug Something isn't working

Comments

@LSC527

LSC527 commented Jul 30, 2024

Your current environment

The output of `python collect_env.py`

🐛 Describe the bug

Inference works with tensor_parallel_size=1, but fails with the error below when tensor_parallel_size>1.

from PIL import Image
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

model_path = "/home/work/MiniCPM-Llama3-V-2_5"

image = Image.open('x.jpg').convert('RGB')

llm = LLM(
    model=model_path,
    trust_remote_code=True,
    tensor_parallel_size=2,
)

question = 'What is in the image?'

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

messages = [{
    'role': 'user',
    'content': f'(<image>./</image>)\n{question}'
}]
prompt = tokenizer.apply_chat_template(messages,
                                        tokenize=False,
                                        add_generation_prompt=True)

sampling_params = SamplingParams(temperature=0.7, max_tokens=512, stop_token_ids=[128001, 128009])

inputs = {
    "prompt": prompt,
    "multi_modal_data": {
        "image": image
    },
}

outputs = llm.generate(inputs, sampling_params=sampling_params)

for o in outputs:
    generated_text = o.outputs[0].text
    print(generated_text)

(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/worker/worker_base.py", line 65, in start_worker_execution_loop
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     output = self.execute_model(execute_model_req=None)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/worker/worker_base.py", line 272, in execute_model
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/worker/model_runner.py", line 1354, in execute_model
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     hidden_or_intermediate_states = model_executable(
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/model_executor/models/minicpmv.py", line 619, in forward
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     vlm_embeddings, vision_hidden_states = self.get_embedding(inputs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/model_executor/models/minicpmv.py", line 562, in get_embedding
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     vision_hidden_states = self.get_vision_hidden_states(data)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/home/work/gitclone/vllm-main/vllm/model_executor/models/minicpmv.py", line 545, in get_vision_hidden_states
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     vision_embedding = self.vpm(
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py", line 617, in forward
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py", line 162, in forward
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     patch_embeds = self.patch_embedding(pixel_values)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return self._conv_forward(input, self.weight, self.bias)
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return F.conv2d(input, weight, bias, self.stride,
(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)
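
For anyone reading the trace above: every line carries the same worker prefix, which buries the actual Python traceback. The following is a minimal sketch, not part of vLLM, that strips that prefix and pulls out the final exception line. The prefix pattern is an assumption inferred only from the log lines in this report, and the helper names `strip_worker_prefix` and `final_error` are hypothetical.

```python
import re

# Assumed prefix shape, inferred from the log lines in this report:
# "(VllmWorkerProcess pid=<n>) ERROR <MM-DD> <HH:MM:SS> <file>:<line>] "
WORKER_PREFIX = re.compile(
    r"^\(VllmWorkerProcess pid=\d+\) ERROR \d{2}-\d{2} \d{2}:\d{2}:\d{2} \S+:\d+\] ?"
)


def strip_worker_prefix(log_text: str) -> str:
    """Remove the per-line worker prefix, keeping the traceback indentation."""
    return "\n".join(WORKER_PREFIX.sub("", line) for line in log_text.splitlines())


def final_error(log_text: str) -> str:
    """Return the last non-indented line of the stripped trace (the exception)."""
    for line in reversed(strip_worker_prefix(log_text).splitlines()):
        if line and not line.startswith(" "):
            return line
    return ""


# Three lines copied from the end of the trace in this report.
sample = "\n".join([
    '(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward',
    '(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226]     return F.conv2d(input, weight, bias, self.stride,',
    '(VllmWorkerProcess pid=11819) ERROR 07-30 20:15:19 multiproc_worker_utils.py:226] RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper_CUDA__cudnn_convolution)',
])

print(strip_worker_prefix(sample))
print(final_error(sample))
```

Read this way, the failing call is the `F.conv2d` inside the vision tower's patch embedding, where the convolution weight and its input sit on different GPUs (cuda:0 and cuda:1).
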
@DarkLight1337
Member

Thanks for reporting this! It is related to the broadcasting logic and should be fixed soon.
