
[Bug] AutoAWQ quantization - Cannot open ndarray-cache.json #1732

Closed
dusty-nv opened this issue Feb 8, 2024 · 7 comments
Labels
bug Confirmed bugs

Comments

dusty-nv commented Feb 8, 2024

🐛 Bug

When trying to run a Llama-7B model that was compiled with mlc_chat and q4f16_autoawq quantization, the conversion never produces the ndarray-cache.json file, and loading the model at runtime fails with the following error:

Traceback (most recent call last):
  File "/opt/mlc-llm/benchmark.py", line 105, in <module>
    cm = ChatModule(model=args.model, model_lib_path=args.model_lib_path, chat_config=cfg)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/chat_module.py", line 780, in __init__
    self._reload(self.model_lib_path, self.model_path, user_chat_config_json_str)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/chat_module.py", line 1008, in _reload
    self._reload_func(lib, model_path, app_config_json, kv_cache_config.asjson())
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  [bt] (8) /usr/local/lib/python3.10/dist-packages/mlc_chat/libmlc_llm_module.so(+0x1d3a84) [0xffff25143a84]
  [bt] (7) /usr/local/lib/python3.10/dist-packages/mlc_chat/libmlc_llm_module.so(mlc::llm::LLMChat::Reload(tvm::runtime::TVMArgValue, tvm::runtime::String, tvm::runtime::String, tvm::runtime::String)+0x5c4) [0xffff251421e4]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<void (std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int, int)>::AssignTypedLambda<void (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int, int)>(void (*)(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int, int), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)+0x18c) [0xffff5fdfbdfc]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::relax_vm::NDArrayCache::Load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, int, int)+0x50) [0xffff5fdfbec0]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::relax_vm::NDArrayCacheMetadata::Load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)+0x9c) [0xffff5fdf5ba0]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::LoadBinaryFromFile(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >*)+0x25c) [0xffff5fd895a0]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x3068748) [0xffff5fd88748]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff5df436a8]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff5fd93230]
  File "/opt/mlc-llm/3rdparty/tvm/src/runtime/file_utils.cc", line 121
InternalError: Check failed: (!fs.fail()) is false: Cannot open /data/models/mlc/16aaa30/slim/Llama-2-7b-chat-hf-awq-q4f16_autoawq/ndarray-cache.json

This is the script I used to first convert the original Llama-2-7b-chat-hf model to AWQ format (using the AutoAWQ library):

#!/usr/bin/env python3
import os
import pprint
import argparse

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

parser = argparse.ArgumentParser()

parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-chat-hf')
parser.add_argument('--output', type=str, default=None, help='output directory to save the quantized model (optional)')
parser.add_argument('--kernel', type=str, default='GEMM', choices=['GEMM', 'GEMV'], help='AWQ kernel to use for quantization')

args = parser.parse_args()
print(args)

# load model
print(f"-- loading {args.model}")

model = AutoAWQForCausalLM.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

# quantize
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": args.kernel }

print(f"-- quantizing {args.model} with AWQ config:")
pprint.pprint(quant_config)

model.quantize(tokenizer, quant_config=quant_config)

# save quantized model
if args.output:
    print(f"-- saving quantized model to {args.output}")
    model.save_quantized(args.output)
    tokenizer.save_pretrained(args.output)

I then used mlc_chat convert_weight/gen_config/compile with --quantization=q4f16_autoawq, and all three steps completed successfully. However, the model doesn't load at runtime due to the error above. I didn't find docs for mlc_chat about using AWQ, so it's unclear whether I missed any steps, what the model directory structure should be, or where the weights are supposed to reside.
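
The sequence was essentially the following; the exact flags are reconstructed from memory, so treat this as a sketch that may differ slightly between versions:

$ mlc_chat convert_weight models/Llama-2-7b-chat-hf-awq/ \
    --quantization q4f16_autoawq \
    --output Llama-2-7b-chat-hf-awq-q4f16_autoawq
$ mlc_chat gen_config models/Llama-2-7b-chat-hf-awq/ \
    --quantization q4f16_autoawq \
    --conv-template llama-2 \
    --output Llama-2-7b-chat-hf-awq-q4f16_autoawq
$ mlc_chat compile Llama-2-7b-chat-hf-awq-q4f16_autoawq/mlc-chat-config.json \
    --device cuda \
    --output Llama-2-7b-chat-hf-awq-q4f16_autoawq/Llama-2-7b-chat-hf-awq-q4f16_autoawq-cuda.so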

This is what the directory of the AutoAWQ model looks like:

ls models/Llama-2-7b-chat-hf-awq/
config.json  generation_config.json  model.safetensors  model.safetensors.index.json  quant_config.json  special_tokens_map.json  tokenizer.json  tokenizer.model  tokenizer_config.json

And this is what the directory of the compiled model from mlc_chat looks like:

ls Llama-2-7b-chat-hf-awq-q4f16_autoawq/
Llama-2-7b-chat-hf-awq-q4f16_autoawq-cuda.so  mlc-chat-config.json  tokenizer.json  tokenizer.model  tokenizer_config.json

So unlike models compiled with q4f16_1 or q4f16_ft quantization, the q4f16_autoawq output directory contains neither the weight shards nor that ndarray-cache.json file. Is there something I am doing wrong?
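
For comparison, this is roughly what one of my q4f16_ft output directories contains (directory and shard names here are illustrative):

ls Llama-2-7b-chat-hf-q4f16_ft/
Llama-2-7b-chat-hf-q4f16_ft-cuda.so  mlc-chat-config.json  ndarray-cache.json  params_shard_0.bin  ...  params_shard_N.bin  tokenizer.json  tokenizer.model  tokenizer_config.json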

Environment

  • Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): ARM64 + CUDA
  • Operating system (e.g. Ubuntu/Windows/MacOS/...): Ubuntu 22.04
  • Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): Jetson AGX Orin
  • How you installed MLC-LLM (conda, source): source 16aaa30
  • How you installed TVM-Unity (pip, source): source https://github.com/mlc-ai/relax/tree/292137088115ac81779607ca223bbbd9ad40cb55
  • Python version (e.g. 3.10): 3.10
  • GPU driver version (if applicable): JetPack 6.0
  • CUDA/cuDNN version (if applicable): CUDA 12.2
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
  • Any other relevant information:
dusty-nv added the bug (Confirmed bugs) label Feb 8, 2024
@CharlieFRuan
Contributor

I believe for AWQ you'd still need to go through mlc_chat convert_weight just like the other quantizations; there are some steps here: #1229. Let me know how it goes. The commands may differ slightly now, since that PR has been out for a while.

I didn't find docs for mlc_chat about using AWQ

This is mainly because these quantizations are still somewhat experimental in MLC. I will add the steps to the docs!

@dusty-nv
Author

OK, I tried a model quantized with llm-awq instead (as in #1229), but I get an error during the convert_weight stage:

$ mlc_chat convert_weight \
  models/Llama-2-7b-chat-hf  \
  --quantization q4f16_autoawq \
  --source /data/models/awq/Llama-2-7b-chat-hf/w4-g128-awq.pt \
  --source-format awq \
  --output Llama-2-7b-chat-hf-q4f16_autoawq

[2024-02-10 05:46:10] INFO auto_config.py:115: Found model configuration: models/Llama-2-7b-chat-hf/config.json
[2024-02-10 05:46:11] INFO auto_device.py:76: Found device: cuda:0
[2024-02-10 05:46:12] INFO auto_device.py:85: Not found device: rocm:0
[2024-02-10 05:46:13] INFO auto_device.py:85: Not found device: metal:0
[2024-02-10 05:46:14] INFO auto_device.py:85: Not found device: vulkan:0
[2024-02-10 05:46:15] INFO auto_device.py:85: Not found device: opencl:0
[2024-02-10 05:46:15] INFO auto_device.py:33: Using device: cuda:0
[2024-02-10 05:46:15] INFO auto_weight.py:70: Finding weights in: /data/models/awq/Llama-2-7b-chat-hf/w4-g128-awq.pt
[2024-02-10 05:46:15] INFO auto_config.py:153: Found model type: llama. Use `--model-type` to override.
Weight conversion with arguments:
  --config          models/Llama-2-7b-chat-hf/config.json
  --quantization    AWQQuantize(name='q4f16_autoawq', kind='awq', group_size=128, quantize_dtype='int4', storage_dtype='uint32', model_dtype='float16', num_elem_per_storage=8, num_storage_per_group=16, max_int_value=7, prebuilt_quantize_func={})
  --model-type      llama
  --device          cuda:0
  --source          /data/models/awq/Llama-2-7b-chat-hf/w4-g128-awq.pt
  --source-format   awq
  --output          Llama-2-7b-chat-hf-q4f16_autoawq
[2024-02-10 05:46:15] INFO llama_model.py:51: context_window_size not found in config.json. Falling back to max_position_embeddings (4096)
[2024-02-10 05:46:15] INFO llama_model.py:71: prefill_chunk_size defaults to context_window_size (4096)
[2024-02-10 05:46:23] INFO huggingface_loader.py:182: Loading HF parameters from: /data/models/awq/Llama-2-7b-chat-hf/w4-g128-awq.pt
[2024-02-10 05:46:27] INFO huggingface_loader.py:172: [Not quantized] Parameter: "model.embed_tokens.weight", shape: (32000, 4096), dtype: float16
[2024-02-10 05:46:27] INFO huggingface_loader.py:172: [Not quantized] Parameter: "model.layers.0.self_attn.qkv_proj.qweight", shape: (4096, 1536), dtype: uint32
[2024-02-10 05:46:27] INFO huggingface_loader.py:172: [Not quantized] Parameter: "model.layers.0.self_attn.qkv_proj.qzeros", shape: (4096, 12), dtype: uint32
  0% 2/451 [00:00<01:34,  4.73it/s]
Traceback (most recent call last):
  File "/usr/local/bin/mlc_chat", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/__main__.py", line 28, in main
    cli.main(sys.argv[2:])
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/cli/convert_weight.py", line 87, in main
    convert_weight(
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/interface/convert_weight.py", line 169, in convert_weight
    _convert_args(args)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/interface/convert_weight.py", line 124, in _convert_args
    _check_param(name, param)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/interface/convert_weight.py", line 102, in _check_param
    raise ValueError(
ValueError: Parameter model.layers.0.self_attn.qkv_proj.qzeros has shape (4096, 12), but expected [32, 1536]

Are you able to run it? I can provide the AWQ model if that's helpful.

@CharlieFRuan
Contributor

Ahh, apologies @dusty-nv, I pointed you to the wrong place :((

I followed this flow (which uses AutoAWQ and an existing repo from TheBloke) and it worked: #1362. Let me know if it works for you!
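
Roughly, that flow downloads the prequantized AWQ weights and points convert_weight at them. A sketch (paths here are examples; see #1362 for the exact commands):

$ mlc_chat convert_weight models/Llama-2-7b-chat-hf/ \
    --quantization q4f16_autoawq \
    --source models/Llama-2-7B-Chat-AWQ/model.safetensors \
    --source-format awq \
    --output Llama-2-7b-chat-hf-q4f16_autoawq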

I will update the doc to put together instructions for running AWQ and FT.

@dusty-nv
Author

Aha, thanks @CharlieFRuan! I am able to run the flow from #1362 on TheBloke/Llama-2-7B-Chat-AWQ (and also on my own copy exported from AutoAWQ in #1732 (comment) above).

However, it is very slow: only ~13.7 tokens/sec of output (whereas Llama-2-7b with q4f16_ft quantization produces 45 tokens/sec). Meanwhile, the original llm-awq project runs at ~30 tokens/sec on the same hardware. Any ideas what might be happening?

@CharlieFRuan
Contributor

I see! Thanks for the report; we'll investigate the degradation this week!

@CharlieFRuan
Contributor

@dusty-nv I think we indeed haven't optimized AWQ's layout well (unlike FasterTransformer, where the slowdown was due to a disparity with the old flow). We will come back to optimize it in the future.

@bethalianovike

Hi @dusty-nv,
I tried using your script above to convert the original Llama-2-7b-chat-hf model to AWQ format, but I encountered the error below. Can you give me some advice on this issue? Thank you in advance!

$ python3 convert_hf_to_awq.py
Namespace(model='meta-llama/Llama-2-7b-chat-hf', output=None, kernel='GEMM')
-- loading meta-llama/Llama-2-7b-chat-hf
Fetching 13 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 99138.09it/s]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.15it/s]
-- quantizing meta-llama/Llama-2-7b-chat-hf with AWQ config:
{'q_group_size': 128, 'version': 'GEMM', 'w_bit': 4, 'zero_point': True}
Repo card metadata block was not found. Setting CardData to empty.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
AWQ:   3%|████▍                                                                                                                                          | 1/32 [00:19<09:49, 19.02s/it]
Traceback (most recent call last):
  File "/home/convert_hf_to_awq.py", line 31, in <module>
    model.quantize(tokenizer, quant_config=quant_config)
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/models/base.py", line 186, in quantize
    self.quantizer.quantize()
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 156, in quantize
    scales_list = [
                  ^
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 157, in <listcomp>
    self._search_best_scale(self.modules[i], **layer)
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 277, in _search_best_scale
    best_scales = self._compute_best_scale(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 334, in _compute_best_scale
    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/miniconda3/envs/mlc-chat-test/lib/python3.11/site-packages/awq/quantize/quantizer.py", line 69, in pseudo_quantize_tensor
    assert torch.isnan(w).sum() == 0
           ^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
