
[Bug]: [V1] Mamba models fail on profile run #12826

Open
1 task done
nopperl opened this issue Feb 6, 2025 · 1 comment
Labels
bug Something isn't working

Comments

nopperl commented Feb 6, 2025

Your current environment

The output of `python collect_env.py`
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (conda-forge gcc 10.4.0-19) 10.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.31.5
Libc version: glibc-2.31

Python version: 3.12.8 | packaged by conda-forge | (main, Dec  5 2024, 14:24:40) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-193-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA L40S

Nvidia driver version: 570.86.10
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True


Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pynvml==11.5.3
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.2
[pip3] triton==3.1.0
[conda] Could not collect
ROCM Version: 5.4.22801-aaa1e3d8
Neuron SDK Version: N/A
vLLM Version: 0.7.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled


NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

Running Mamba models using the v1 engine fails due to an error in the MambaCacheManager during the profile run.

import os

# Select the V1 engine before vLLM is imported so the flag is picked up reliably.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

#llm = LLM(model="ai21labs/Jamba-tiny-dev", enable_prefix_caching=False)
llm = LLM(model="state-spaces/mamba-130m-hf", enable_prefix_caching=False)

outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

ERROR 02-06 12:00:04 core.py:208] EngineCore hit an exception: Traceback (most recent call last):
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 200, in run_engine_core
ERROR 02-06 12:00:04 core.py:208]     engine_core = EngineCoreProc(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 154, in __init__
ERROR 02-06 12:00:04 core.py:208]     super().__init__(vllm_config, executor_class)
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 52, in __init__
ERROR 02-06 12:00:04 core.py:208]     num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
ERROR 02-06 12:00:04 core.py:208]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 77, in _initialize_kv_caches
ERROR 02-06 12:00:04 core.py:208]     availble_gpu_memory = self.model_executor.determine_available_memory()
ERROR 02-06 12:00:04 core.py:208]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/executor/abstract.py", line 59, in determine_available_memory
ERROR 02-06 12:00:04 core.py:208]     output = self.collective_rpc("determine_available_memory")
ERROR 02-06 12:00:04 core.py:208]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 49, in collective_rpc
ERROR 02-06 12:00:04 core.py:208]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 02-06 12:00:04 core.py:208]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/utils.py", line 2208, in run_method
ERROR 02-06 12:00:04 core.py:208]     return func(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 02-06 12:00:04 core.py:208]     return func(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 163, in determine_available_memory
ERROR 02-06 12:00:04 core.py:208]     self.model_runner.profile_run()
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 985, in profile_run
ERROR 02-06 12:00:04 core.py:208]     hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
ERROR 02-06 12:00:04 core.py:208]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 02-06 12:00:04 core.py:208]     return func(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 870, in _dummy_run
ERROR 02-06 12:00:04 core.py:208]     hidden_states = model(
ERROR 02-06 12:00:04 core.py:208]                     ^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 02-06 12:00:04 core.py:208]     return self._call_impl(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 02-06 12:00:04 core.py:208]     return forward_call(*args, **kwargs)
ERROR 02-06 12:00:04 core.py:208]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/model_executor/models/mamba.py", line 237, in forward
ERROR 02-06 12:00:04 core.py:208]     ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
ERROR 02-06 12:00:04 core.py:208]         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208]   File "/scratch/micromamba/envs/vllm-0.7.1/lib/python3.12/site-packages/vllm/model_executor/models/mamba_cache.py", line 50, in current_run_tensors
ERROR 02-06 12:00:04 core.py:208]     request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
ERROR 02-06 12:00:04 core.py:208]                              ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-06 12:00:04 core.py:208] KeyError: 'request_ids_to_seq_ids'
ERROR 02-06 12:00:04 core.py:208]
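
For context, the traceback comes down to the V0-style cache manager indexing a kwarg that the V1 profile run never supplies. Below is a minimal, simplified sketch of that failure mode (names taken from the traceback; this is not the actual vLLM implementation):

# Simplified sketch: MambaCacheManager.current_run_tensors() expects a kwarg
# that only the V0 model runner passes; V1's _dummy_run() calls forward()
# without it, so the lookup fails during the profile run.
def current_run_tensors(input_ids, attn_metadata, **kwargs):
    request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
    return request_ids_to_seq_ids

# V1 profile run, schematically: the kwarg is missing, so a KeyError is raised.
try:
    current_run_tensors(input_ids=None, attn_metadata=None)
except KeyError as exc:
    print(f"KeyError: {exc}")  # KeyError: 'request_ids_to_seq_ids'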

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
nopperl added the bug (Something isn't working) label on Feb 6, 2025
tlrmchlsmth (Collaborator) commented

That's right, Mamba models are unsupported in vLLM V1. The way conv and ssm state is managed for V0 is a bit of a hack and I think we'll want to engineer it the right way for V1.

@heheda12345 is working on a series of PRs on the memory allocator (#11382), which look like they will address Mamba in the future, so I'm hoping that this is the path to supporting Mamba models in V1.
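
In the meantime, a possible workaround (assuming the V0 engine path described above still handles Mamba) is simply not to opt into V1, i.e. leave VLLM_USE_V1 unset or set it to 0 before the engine is created:

import os

# Workaround sketch: stay on the V0 engine, which still has the MambaCacheManager path.
os.environ["VLLM_USE_V1"] = "0"

from vllm import LLM, SamplingParams

llm = LLM(model="state-spaces/mamba-130m-hf", enable_prefix_caching=False)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)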
