[Bug] SLM - mlc_chat convert_weight has errors with q4f16_ft quantization #1723

Closed
dusty-nv opened this issue Feb 7, 2024 · 14 comments
Labels: bug (Confirmed bugs)

@dusty-nv commented Feb 7, 2024

🐛 Bug

When using the mlc_chat convert_weight method on llama-2-7b-chat-hf with q4f16_ft quantization, it throws the following exception:

$ mlc_chat convert_weight models/Llama-2-7b-chat-hf  --quantization q4f16_ft --output Llama-2-7b-chat-hf-q4f16_ft

[2024-02-07 17:57:15] INFO auto_config.py:115: Found model configuration: models/Llama-2-7b-chat-hf/config.json
[2024-02-07 17:57:17] INFO auto_device.py:76: Found device: cuda:0
[2024-02-07 17:57:17] INFO auto_device.py:85: Not found device: rocm:0
[2024-02-07 17:57:18] INFO auto_device.py:85: Not found device: metal:0
[2024-02-07 17:57:19] INFO auto_device.py:85: Not found device: vulkan:0
[2024-02-07 17:57:20] INFO auto_device.py:85: Not found device: opencl:0
[2024-02-07 17:57:20] INFO auto_device.py:33: Using device: cuda:0
[2024-02-07 17:57:20] INFO auto_weight.py:70: Finding weights in: models/Llama-2-7b-chat-hf
[2024-02-07 17:57:20] INFO auto_weight.py:120: Found source weight format: huggingface-torch. Source configuration: models/Llama-2-7b-chat-hf/pytorch_model.bin.index.json
[2024-02-07 17:57:20] INFO auto_weight.py:143: Found source weight format: huggingface-safetensor. Source configuration: models/Llama-2-7b-chat-hf/model.safetensors.index.json
[2024-02-07 17:57:20] INFO auto_weight.py:106: Using source weight configuration: models/Llama-2-7b-chat-hf/pytorch_model.bin.index.json. Use `--source` to override.
[2024-02-07 17:57:20] INFO auto_weight.py:110: Using source weight format: huggingface-torch. Use `--source-format` to override.
[2024-02-07 17:57:20] INFO auto_config.py:153: Found model type: llama. Use `--model-type` to override.
Weight conversion with arguments:
  --config          models/Llama-2-7b-chat-hf/config.json
  --quantization    FTQuantize(name='q4f16_ft', kind='ft-quant', quantize_dtype='int4', storage_dtype='int8', model_dtype='float16', group_size=None, num_elem_per_storage=2, max_int_value=7)
  --model-type      llama
  --device          cuda:0
  --source          models/Llama-2-7b-chat-hf/pytorch_model.bin.index.json
  --source-format   huggingface-torch
  --output          Llama-2-7b-chat-hf-q4f16_ft
[2024-02-07 17:57:20] INFO llama_model.py:51: context_window_size not found in config.json. Falling back to max_position_embeddings (4096)
[2024-02-07 17:57:20] INFO llama_model.py:71: prefill_chunk_size defaults to context_window_size (4096)
Traceback (most recent call last):
  File "/usr/local/bin/mlc_chat", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/__main__.py", line 28, in main
    cli.main(sys.argv[2:])
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/cli/convert_weight.py", line 87, in main
    convert_weight(
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/interface/convert_weight.py", line 169, in convert_weight
    _convert_args(args)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/interface/convert_weight.py", line 68, in _convert_args
    model, quantize_map = args.model.quantize[args.quantization.kind](
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/model/llama/llama_quantization.py", line 37, in ft_quant
    model = quantization.quantize_model(
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/quantization/ft_quantization.py", line 134, in quantize_model
    model = mutator.visit(name_prefix, model)
  File "/usr/local/lib/python3.10/dist-packages/tvm/relax/frontend/nn/visitor.py", line 140, in visit
    setattr(node, key, self.visit_module(_get_child_name(name, key), value))
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/quantization/ft_quantization.py", line 123, in visit_module
    return FTQuantizeLinear.from_linear(node, self.config)
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/quantization/ft_quantization.py", line 329, in from_linear
    quantized_linear = FTQuantizeLinear(
  File "/usr/local/lib/python3.10/dist-packages/mlc_chat/quantization/ft_quantization.py", line 298, in __init__
    (in_features, tir.ceildiv(out_features, config.num_elem_per_storage)),
  File "/usr/local/lib/python3.10/dist-packages/tvm/tir/op.py", line 3053, in ceildiv
    return _ffi_api._OpCeilDiv(lhs, rhs, span)  # type: ignore
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/usr/local/lib/python3.10/dist-packages/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (7) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(TVMFuncCall+0x68) [0xffff606fac08]
  [bt] (6) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x1ff1bb4) [0xffff5f6f1bb4]
  [bt] (5) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::TVMMovableArgValueWithContext_::operator tvm::PrimExpr<tvm::PrimExpr>() const+0x94) [0xffff5e8fc5e4]
  [bt] (4) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::PackedFuncValueConverter<tvm::PrimExpr>::From(tvm::runtime::TVMPODValue_ const&)+0x4c) [0xffff5e8efafc]
  [bt] (3) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::ObjectRef tvm::runtime::TVMPODValue_::AsObjectRef<tvm::runtime::ObjectRef>() const+0x254) [0xffff5e8efa84]
  [bt] (2) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(+0x11e7588) [0xffff5e8e7588]
  [bt] (1) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::detail::LogFatal::Entry::Finalize()+0x68) [0xffff5e8ece78]
  [bt] (0) /usr/local/lib/python3.10/dist-packages/tvm/libtvm.so(tvm::runtime::Backtrace[abi:cxx11]()+0x30) [0xffff60742820]
  File "/opt/mlc-llm/3rdparty/tvm/include/tvm/runtime/packed_func.h", line 785
TVMError: In function tir._OpCeilDiv(0: PrimExpr, 1: PrimExpr, 2: Span) -> PrimExpr: error while converting argument 0: [17:57:20] /opt/mlc-llm/3rdparty/tvm/include/tvm/runtime/packed_func.h:2080: InternalError: Check failed: type_code_ == kTVMObjectHandle (11 vs. 8) : expected Object but got str

The same model and install work with the previous mlc_llm.build method using q4f16_ft. q4f16_1 quantization does work with the SLM mlc_chat model builder, but it later hits issue #1551 (comment) at runtime.

Environment

  • Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): CUDA
  • Operating system (e.g. Ubuntu/Windows/MacOS/...): Ubuntu 22.04
  • Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): Jetson AGX Orin
  • How you installed MLC-LLM (conda, source): Source (commit d840de5)
  • How you installed TVM-Unity (pip, source): Source
  • Python version (e.g. 3.10): 3.10
  • GPU driver version (if applicable): JetPack 6.0
  • CUDA/cuDNN version (if applicable): CUDA 12.2
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
  • Any other relevant information:
dusty-nv added the bug (Confirmed bugs) label on Feb 7, 2024
CharlieFRuan self-assigned this on Feb 7, 2024
@CharlieFRuan (Contributor)

Hi @dusty-nv, thanks for reporting the issue! I was able to reproduce it on my end; I'm looking into it and will follow up with a fix soon.

Meanwhile, as a workaround for q4f16_ft, you can manually change llama_model.py (a minimal sketch of why the original lines fail follows this list):

  • From self.lm_head = nn.Linear(config.hidden_size, "vocab_size", bias=False)
    to self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

  • From self.embed_tokens = nn.Embedding("vocab_size", config.hidden_size)
    to self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
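
For context, here is a minimal repro sketch (not taken from the issue) of why the original lines break q4f16_ft: FT quantization computes its packed int8 storage shape with tir.ceildiv on out_features, and the placeholder string "vocab_size" cannot be converted to a PrimExpr, which is the "expected Object but got str" check failure in the traceback above.

```python
# Illustrative only: reproduce the ceildiv failure outside of mlc_chat,
# assuming a local TVM Unity install.
import tvm
from tvm import tir

print(tir.ceildiv(32000, 2))      # both operands convert to PrimExpr -> fine
try:
    tir.ceildiv("vocab_size", 2)  # a plain Python str is not a PrimExpr
except tvm.TVMError as err:
    print("fails like the traceback above:", err)
```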

Regarding the runtime issue, I commented on the other issue. If you are using the prebuilt pip package for TVM, updating it should suffice. If you are building from source, updating to HEAD and rebuilding should work (to be safe, you can rm -rf build and follow these steps again).

@dusty-nv (Author) commented Feb 8, 2024

Thanks @CharlieFRuan, will try this!

@CharlieFRuan (Contributor)

The issue for q4f16_ft is fixed via #1731 (i.e. you do not have to manually change the code now). Let us know if there are further issues!

@dusty-nv (Author) commented Feb 9, 2024

Thanks @CharlieFRuan! q4f16_ft is working again with #1731.

However, the q4f16_ft model built with mlc_chat benchmarks ~15% slower than the same model built with mlc_llm.build (measured using ChatModule.benchmark_generate()). For mlc_chat I compile the model with --opt O3, and for mlc_llm.build with --use-cuda-graph --use-flash-attn-mqa.

Do you see similar perf regressions with the mlc_chat SLM model compiler vs mlc_llm.build? It doesn't seem specific to q4f16_ft quantization; the other quantization methods compare similarly with the newer SLM flow.
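
For reference, a hedged sketch of how such a comparison can be scripted with ChatModule.benchmark_generate(); the model and library paths below are hypothetical placeholders, not paths taken from this issue:

```python
# Hedged benchmark sketch; directory and .so names below are placeholders.
from mlc_chat import ChatModule

candidates = {
    # new SLM flow: mlc_chat convert_weight + mlc_chat compile --opt O3
    "slm": ("dist/Llama-2-7b-chat-hf-q4f16_ft", "dist/Llama-2-7b-chat-hf-q4f16_ft-cuda.so"),
    # old flow: mlc_llm.build --use-cuda-graph --use-flash-attn-mqa
    "legacy": ("dist/llama-2-7b-chat-hf-q4f16_ft-old", "dist/llama-2-7b-chat-hf-old-cuda.so"),
}

for name, (model_dir, lib_path) in candidates.items():
    cm = ChatModule(model=model_dir, model_lib_path=lib_path, device="cuda")
    cm.benchmark_generate(prompt="What is the meaning of life?", generate_length=256)
    print(name, cm.stats())  # prefill/decode tok/s reported by the runtime
```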

@CharlieFRuan (Contributor)

Really appreciate you reporting the regression, @dusty-nv! We'll look into this.

@srikanthsrnvs commented Feb 14, 2024

Bumping this! We also noticed several performance hits. Another issue is that the Mixtral model build process fails in the same environment (Jetson Orin AArch64).

MasterJH5574 pushed a commit that referenced this issue Feb 17, 2024
This PR reverts the change introduced in #1731 where we skip quantizing `lm_head` to avoid the dynamic vocab issue.

This led to performance degradation as pointed out in #1723.

Instead, we fall back to `GroupQuantizeLinear` for `lm_head`, which preserves performance and avoids the dynamic vocab size issue.

Performance change on RTX 4090

Prefill: 
throughput: **950.792 --> 973.859 tok/s**
total tokens: 7 tok

Decode:
throughput: **214.372 --> 223.491 tok/s**
total tokens: 256 tok

@CharlieFRuan (Contributor)

Hi @dusty-nv @srikanthsrnvs! With the two PRs that just got merged, we were able to increase decode from 214 tok/s to ~227 tok/s on RTX 4090 for Llama 2 7B q4f16_ft.

We have aligned most of the performance, but there is still a minor gap versus the old flow because SLM does not yet have CUDA graph support. We will follow up on that.

@dusty-nv (Author) commented Feb 18, 2024

Thanks @CharlieFRuan, I've confirmed with MLC @ 6cf63bb that the performance has improved! This was on Jetson AGX Orin:

  • MLC @ 6cf63bb, mlc_chat compile, llama-2-7b-q4f16_ft -> 45 tokens/sec
  • MLC @ 3feed05, mlc_llm.build, llama-2-7b-q4f16_ft -> 47 tokens/sec

So yeah, that is close, thank you! Please update this topic when CUDA graph support has been added to mlc_chat compile.

P.S. In this latest build I have started seeing errors with mlc_llm.build (#1779), so I hope we can resolve the open issues with migrating to mlc_chat compile soon 👍

@srikanthsrnvs

Thanks @CharlieFRuan

I'm curious what performance we can expect from Mixtral 8x7B on an Orin vs Mistral 7B. The Mixtral model gives me performance closer to a 13B model when using non-MLC implementations, but right now I am only getting 7 tokens/s on Mixtral vs 45 tokens/s with a 7B model.

Curious whether CUDA graph is the reason behind this, but I suspect that would only be a minor ~5% improvement. Any ideas?

@dusty-nv perhaps you've tested out Mixtral as well?

@srikanthsrnvs

Hi @CharlieFRuan, just bumping this: is Mixtral fully optimized in MLC, or are there still optimizations to be made? I would love it if you could point me in the right direction so I could take a look myself!

@CharlieFRuan (Contributor)

Hi @srikanthsrnvs! Unfortunately, I am not too familiar with Mixtral in MLC. Feel free to open a separate issue for this so people with more knowledge can respond. Thank you!

@MasterJH5574 (Member)

> just bumping this: is Mixtral fully optimized in MLC, or are there still optimizations to be made? I would love it if you could point me in the right direction so I could take a look myself!

Hi @srikanthsrnvs, Mixtral on CUDA should be mostly optimized. Here is another piece we observed recently that might help performance: #1866 -- we can leverage Thrust sort for the expert selection. After that gets merged, Mixtral should be well optimized, I think. We are happy to dig in more if any perf regression or perf issues pop up.
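
Note the Thrust path assumes your TVM build has USE_THRUST enabled. As a rough sketch (the packed-func name below is an assumption on my side, not something confirmed in this thread), you can check a local build from Python:

```python
# Rough check sketch: Thrust kernels register packed functions such as
# "tvm.contrib.thrust.sort" when TVM is built with USE_THRUST ON.
import tvm

has_thrust = tvm.get_global_func("tvm.contrib.thrust.sort", allow_missing=True) is not None
print("Thrust available in this TVM build:", has_thrust)
```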

@srikanthsrnvs

> Mixtral on CUDA should be mostly optimized. Here is another piece we observed recently that might help performance: #1866 -- we can leverage Thrust sort for the expert selection.

How do I actually use Thrust? #1866 (comment)

@MasterJH5574 (Member)

Thank you @srikanthsrnvs for following up. Given Thrust is not the focus of this issue, I'm going to close this one. We can follow up the discussion in #1866.
