
[Bug]: Try-catch conditions are incorrect to import correct ROCm Flash Attention Backend in Draft Model #9100

Closed
tjtanaa opened this issue Oct 6, 2024 · 0 comments · Fixed by #9101
Labels
bug Something isn't working

Comments

tjtanaa (Contributor) commented Oct 6, 2024

Your current environment

The output of `python collect_env.py`

Model Input Dumps

No response

🐛 Describe the bug

I found an issue when running draft-model speculative decoding on the AMD platform; it arises from vllm/spec_decode/draft_model_runner.py:

try:
    # NOTE: on ROCm this raises ImportError, not ModuleNotFoundError
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
    # vllm_flash_attn is not installed, use the identical ROCm FA metadata
    from vllm.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)
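
For context, ModuleNotFoundError is a subclass of ImportError, not the other way around, so `except ModuleNotFoundError` does not catch a plain ImportError:

>>> issubclass(ModuleNotFoundError, ImportError)
True
>>> issubclass(ImportError, ModuleNotFoundError)
False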

Within the try-except block, an ImportError is raised rather than a ModuleNotFoundError, so the ROCm fallback import never runs:

  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/engine/multiprocessing/engine.py", line 78, in __init__                          
    self.engine = LLMEngine(*args,                                                                                                          
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/engine/llm_engine.py", line 335, in __init__                                     
    self.model_executor = executor_class(                                                                                                   
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/distributed_gpu_executor.py", line 26, in __init__                      
    super().__init__(*args, **kwargs)                                                                                                       
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/executor_base.py", line 47, in __init__                                 
    self._init_executor()                                                                                                                   
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/multiproc_gpu_executor.py", line 108, in _init_executor                 
    self.driver_worker = self._create_worker(                                                                                               
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/gpu_executor.py", line 105, in _create_worker                           
    return create_worker(**self._get_create_worker_kwargs(                                                                                  
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/executor/gpu_executor.py", line 24, in create_worker                             
    wrapper.init_worker(**kwargs)                                                                                                           
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/worker/worker_base.py", line 446, in init_worker                                 
    mod = importlib.import_module(self.worker_module_name)                                                                                  
  File "/home/aac/anaconda3/envs/rocm611-0929/lib/python3.9/importlib/__init__.py", line 127, in import_module                              
    return _bootstrap._gcd_import(name[level:], package, level)                                                                             
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import                                                                           
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load                                                                        
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked                                                                
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked                                                                         
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module                                                                   
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed                                                              
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/spec_decode/spec_decode_worker.py", line 21, in <module>                         
    from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/spec_decode/draft_model_runner.py", line 9, in <module>
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
  File "/home/aac/apps/rocm611-0929/vllm-fix-spec-amd/vllm/attention/backends/flash_attn.py", line 23, in <module>
    from vllm.vllm_flash_attn import (flash_attn_varlen_func,                                                                               
ImportError: cannot import name 'flash_attn_varlen_func' from 'vllm.vllm_flash_attn' (unknown location)
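
As a minimal sketch of one possible fix (illustrative only; the actual change is in the linked #9101): catching ImportError instead covers both failure modes, since ModuleNotFoundError is a subclass of ImportError, so the ROCm fallback is taken whenever the CUDA flash-attn import fails for either reason.

try:
    from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ImportError:
    # ImportError also catches ModuleNotFoundError, so this fallback now
    # triggers on ROCm, where vllm.vllm_flash_attn exists as a package
    # but does not export flash_attn_varlen_func.
    from vllm.attention.backends.rocm_flash_attn import (
        ROCmFlashAttentionMetadata as FlashAttentionMetadata)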

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@tjtanaa tjtanaa added the bug Something isn't working label Oct 6, 2024
@tjtanaa tjtanaa changed the title [Bug]: Try-catch conditions are incorrect to import correct Flash Attention Backend in Draft Model [Bug]: Try-catch conditions are incorrect to import correct Flash Attention Backend (ROCm or CUDA) in Draft Model Oct 6, 2024
@tjtanaa tjtanaa changed the title [Bug]: Try-catch conditions are incorrect to import correct Flash Attention Backend (ROCm or CUDA) in Draft Model [Bug]: Try-catch conditions are incorrect to import correct ROCm Flash Attention Backend in Draft Model Oct 6, 2024