Export.py
import fire
import sys
import os
import json
from pathlib import Path
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../..")
import mixtral.modeling.Loader as Loader
from ModelParams import ModelParams

def main(
    ckpt_dir: str,
    export_path: str,
    friendly_gqa: bool = False, # handle GQA by repeating key and value in the key_value_cache op
    fused_qkv: bool = True, # fuse the q, k and v linear projections
    fused_kvcache: bool = True, # fuse key_value_cache and multi_head_attention
    fused_ffn_glu: bool = True, # fuse the feed-forward gated linear unit
    auto_causal: bool = True, # the causal mask is applied inside the attention op, so no extra mask needs to be passed to the model
    quantized_cache: bool = True, # 8-bit KV cache quantization
    cache_layout: int = 0, # change the KV cache layout for hardware-friendly performance
    cache_mode: int = 0, # change the KV cache indexing mode for easier memory management; only takes effect when dynamic_batching == True
    dynamic_batching: bool = True, # use dynamic batching scheduling
):
    with open(Path(ckpt_dir) / "opmx_params.json", "r") as f:
        params = json.loads(f.read())
    params: ModelParams = ModelParams(**params)

    generator = Loader.load(
        ckpt_dir, params, friendly_gqa,
        fused_qkv, fused_kvcache, fused_ffn_glu,
        auto_causal, quantized_cache, cache_layout,
        cache_mode, dynamic_batching,
        True, False, False, True,
        0
    )

    generator.export(export_path)

if __name__ == "__main__":
    fire.Fire(main)
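
Because the entry point is wrapped in fire.Fire(main), each keyword argument of main is exposed as a command-line flag. A minimal invocation sketch (the checkpoint and output paths below are placeholders, not paths from this repository):

    python Export.py --ckpt_dir /path/to/mixtral_opmx_checkpoint --export_path /path/to/exported_model

Additional flags such as --quantized_cache False or --dynamic_batching False can be appended to override the defaults listed in the signature above.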