
Mixtral support #1529

Merged: merged 1 commit into mlc-ai:main from mixtral-final on Jan 8, 2024
Conversation

Member

@jinhongyii jinhongyii commented Jan 2, 2024

This PR introduces support for Mixtral MoE models with MLC's latest SLM
quantization/compilation pipeline. It includes the following changes:

**Operators.** We implemented a set of operators in TIR's TVMScript
format in two files, `moe_misc` and `moe_matmul`. These TIR kernels
implement "transpose indices" and "blocked-CSR-COO" as described in
MegaBlocks [1].

`moe_misc.py` primarily concerns sparsity-related operators, including:
- `get_indices`, `get_indptr` and `scatter_output`: CSR-style index
  manipulation and array shuffling that make the input range each expert
  handles contiguous (see the sketch after this list).
- `moe_sum`, `moe_cumsum` and `topk`: standard operators specialized for
  MoE use cases, e.g. where the number of experts and the number of
  activated experts are small.
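
For intuition, here is a minimal NumPy sketch of the gather/scatter idea behind these operators, using top-1 routing and a toy per-expert function; names and shapes are illustrative and do not correspond to the actual TIR kernels:

```python
# Hypothetical NumPy sketch (not the TIR kernels): reorder tokens by assigned
# expert so every expert reads a contiguous slice, then scatter results back.
import numpy as np

num_experts = 8
x = np.random.randn(16, 32).astype("float32")            # (tokens, hidden)
expert_id = np.random.randint(0, num_experts, size=16)   # top-1 routing for brevity

order = np.argsort(expert_id, kind="stable")             # role of get_indices
counts = np.bincount(expert_id, minlength=num_experts)
indptr = np.concatenate(([0], np.cumsum(counts)))        # role of get_indptr (via cumsum)

x_sorted = x[order]                                      # contiguous rows per expert
y_sorted = np.empty_like(x_sorted)
for e in range(num_experts):
    lo, hi = indptr[e], indptr[e + 1]
    y_sorted[lo:hi] = x_sorted[lo:hi] * (e + 1)          # stand-in for expert e's computation

y = np.empty_like(x)
y[order] = y_sorted                                      # role of scatter_output
```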

`moe_matmul.py` includes non-quantized and quantized GEMV and group GEMM
operators used in MoE model serving. Typically, in single-batch
decoding, GEMV operators suffice, but group GEMM is a necessary
dependency in both prefill and batched decoding.
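
As a rough illustration of the group GEMM pattern (a sketch under assumed shapes, not the project's kernels), each expert multiplies only its own contiguous block of rows by its own weight matrix:

```python
# Hypothetical NumPy sketch of a grouped GEMM: one (possibly tiny) GEMM per
# expert over the contiguous row ranges produced by the shuffling above.
import numpy as np

num_experts, hidden, ffn = 8, 32, 64
w = np.random.randn(num_experts, hidden, ffn).astype("float32")  # batched expert weights
x_sorted = np.random.randn(16, hidden).astype("float32")         # rows already grouped by expert
indptr = np.arange(0, 18, 2)                                     # toy per-expert row ranges

out = np.empty((16, ffn), dtype="float32")
for e in range(num_experts):
    lo, hi = indptr[e], indptr[e + 1]
    out[lo:hi] = x_sorted[lo:hi] @ w[e]   # degenerates to a GEMV when hi - lo == 1
```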

**Model architecture.** We reuse the attention building block from
Mistral and implement the MoE MLP in `mixtral_model.py`. In Mixtral,
there are three groups of experts in each MLP, where `e1` and `e3` are
the gate/up projections (project-in) and `e2` is the down projection
(project-out).
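
For intuition, a hedged NumPy sketch of one Mixtral-style MoE MLP layer with top-2 routing; the names (`e1`, `e2`, `e3`, `router_w`) follow the description above, but shapes and details are illustrative rather than the code in `mixtral_model.py`:

```python
# Hypothetical sketch of a Mixtral-style MoE MLP forward pass (top-2 routing).
import numpy as np

def silu(v):
    return v / (1.0 + np.exp(-v))

def moe_mlp(x, router_w, e1, e2, e3, top_k=2):
    # x: (tokens, hidden); e1, e3: (Ne, hidden, ffn) project-in; e2: (Ne, ffn, hidden) project-out
    logits = x @ router_w                                        # (tokens, Ne) router scores
    top = np.argsort(-logits, axis=1)[:, :top_k]                 # top-k expert ids per token
    sel = np.take_along_axis(logits, top, axis=1)
    gate = np.exp(sel) / np.exp(sel).sum(axis=1, keepdims=True)  # renormalize over selected experts

    out = np.zeros_like(x)
    for t in range(x.shape[0]):
        for k in range(top_k):
            e = top[t, k]
            h = silu(x[t] @ e1[e]) * (x[t] @ e3[e])              # gated project-in
            out[t] += gate[t, k] * (h @ e2[e])                   # project-out, weighted by the router
    return out

ne, hidden, ffn = 8, 32, 64
out = moe_mlp(
    np.random.randn(4, hidden).astype("float32"),
    np.random.randn(hidden, ne).astype("float32"),
    np.random.randn(ne, hidden, ffn).astype("float32"),
    np.random.randn(ne, ffn, hidden).astype("float32"),
    np.random.randn(ne, hidden, ffn).astype("float32"),
)
```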

**Weight quantization.** We batch all experts of the same kind into a
single tensor of shape `(Ne, N, K)`, where `Ne` is the total number of
experts, `N` is the number of output features and `K` is the number of
input features. Applying group quantization, we compress along the `K`
dimension, consistent with the rest of the project.
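
As a toy illustration of group quantization along `K` (assuming a group size of 32 and a symmetric 4-bit encoding; MLC's actual `q4f16_1` packing differs in detail):

```python
# Hypothetical sketch: quantize batched expert weights (Ne, N, K) in groups
# along K, keeping one fp16 scale per group.
import numpy as np

def group_quantize(w, group_size=32, bits=4):
    ne, n, k = w.shape
    wg = w.reshape(ne, n, k // group_size, group_size)
    scale = np.abs(wg).max(axis=-1, keepdims=True) / (2 ** (bits - 1) - 1) + 1e-8
    q = np.clip(np.round(wg / scale), -(2 ** (bits - 1)), 2 ** (bits - 1) - 1)
    return q.astype(np.int8), scale.astype(np.float16)

def group_dequantize(q, scale):
    w_hat = q.astype(np.float32) * scale.astype(np.float32)
    return w_hat.reshape(q.shape[0], q.shape[1], -1)

w = np.random.randn(8, 64, 128).astype("float32")   # (Ne, N, K)
q, scale = group_quantize(w)
w_hat = group_dequantize(q, scale)                   # approximate reconstruction of w
```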

**Performance.** The current TIR is highly optimized for non-tensor-core
scenarios (Metal, WebGPU, non-TensorCore CUDA, AMD, etc.); tensor-core
performance is left for a follow-up PR in the near future.

**Try out MLC's Mixtral Model.** The int4-quantized Mixtral model takes
about 24.5 GB of weights.

```python
from mlc_chat import ChatConfig, ChatModule, callback
from mlc_chat.support import logging
logging.enable_logging()

MODEL = "HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"
NUM_GPU = 1

def main():
    cm = ChatModule(MODEL, device="cuda:0", chat_config=ChatConfig(
        sliding_window_size=1024,
        tensor_parallel_shards=NUM_GPU,
    ))
    cm.generate("What is the meaning of life?", progress_callback=callback.StreamToStdout(callback_interval=2))

if __name__ == "__main__":
    main()
```

Quantization formats:
- 3-bit (19.662 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC)
- 4-bit (24.466 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC)

The 3-bit version can be run comfortably on a 24 GB GPU (e.g. 4090,
3090 Ti).

**Convert Mixtral to MLC format from scratch.** The following instructions
are only needed for advanced users who want to quantize Mixtral from scratch.

```bash
SRC_DIR=/path/to/Mixtral-8x7B-v0.1    # raw model downloaded from HuggingFace
MODEL_DIR=/mlc_models/mixtral-q4f16_1 # destination directory

mlc_chat gen_config $SRC_DIR -o $MODEL_DIR --quantization q4f16_1 \
  --conv-template LM  # "LM" (lang model) means no conversation template yet

mlc_chat convert_weight $SRC_DIR --quantization q4f16_1 -o $MODEL_DIR
```

[1] Gale, Trevor, Deepak Narayanan, Cliff Young, and Matei Zaharia.
"MegaBlocks: Efficient Sparse Training with Mixture-of-Experts."
Proceedings of MLSys 2023.

Co-authored-by: Junru Shao <[email protected]>

@junrushao junrushao force-pushed the mixtral-final branch 4 times, most recently from 63d5dab to 2bbd9d7 on January 5, 2024 06:04
@junrushao junrushao force-pushed the mixtral-final branch 10 times, most recently from 1ac89bc to a4cfb6b on January 6, 2024 23:24
@junrushao junrushao force-pushed the mixtral-final branch 5 times, most recently from 6ac164a to 59dd302 on January 7, 2024 00:39
@junrushao junrushao force-pushed the mixtral-final branch 9 times, most recently from 5243308 to 9bfadbb on January 8, 2024 05:35
Contributor

@LeshengJin LeshengJin left a comment


LGTM! Thanks @jinhongyii @junrushao

@LeshengJin LeshengJin merged commit e32c6c9 into mlc-ai:main Jan 8, 2024
junrushao added a commit to junrushao/mlc-llm that referenced this pull request Jan 8, 2024
A follow-up of my previous PR (mlc-ai#1529).

This PR makes Mixtral work on the Metal GPUs that macOS comes with.
There are honestly not many changes needed, except that Metal doesn't
support fp64 data types.

A python script to run Mixtral:

```python
from mlc_chat import ChatConfig, ChatModule, callback
from mlc_chat.support import logging
logging.enable_logging()

MODEL = "HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"
NUM_GPU = 1

def main():
    cm = ChatModule(MODEL, chat_config=ChatConfig(
        sliding_window_size=1024,
        tensor_parallel_shards=NUM_GPU,
    ))
    cm.generate("What is the meaning of life?", progress_callback=callback.StreamToStdout(callback_interval=2))

if __name__ == "__main__":
    main()
```

Quantization formats:
- 3-bit (19.662 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q3f16_1-MLC)
- 4-bit (24.466 GB): ["HF://junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC"](https://huggingface.co/junrushao/Mixtral-8x7B-Instruct-v0.1-q4f16_1-MLC)
@junrushao
Member

A follow-up PR: #1558

jinhongyii pushed a commit that referenced this pull request Jan 8, 2024
@junrushao
Member

NOTE: this may take a few extra days until 4 outstanding PRs in TVM get merged. For those who are curious, I have a working branch of TVM if you'd love to build it from source: https://github.com/junrushao/tvm/commits/mixtral-debug/

smickey040404 added a commit to smickey040404/mlc-llm that referenced this pull request Feb 11, 2025
tristankincaid added a commit to tristankincaid/mlc-llm that referenced this pull request Feb 16, 2025