Add Python API for Weight Conversion (mlc-ai#1182)
This PR primarily performs a major refactoring to introduce a Python API that
is consistent with the CLI API; a usage sketch of the new API appears after
the changed-files summary below. In addition, it includes the following fixes
and enhancements:

- Provide more information to `isort` in `pyproject.toml` for better import
  formatting;
- Print the default value of every argument in the argparse command-line help;
- Ensure `--device` is always available locally when doing weight
  conversion;
- Add argument echoing in weight conversion to be consistent with its
  counterpart in compilation;
- Add a consistency checker to make sure the shapes/dtypes of all
  tensors from weight conversion are consistent with compilation;
- Echo the total size of parameters;
- Better logging of each parameter's shape and dtype, and whether or not
  it is quantized;
- Make the package structure more robust, renaming `parameter/` to `loader/`
  to be more explicit about its intent;
- Inline `ParamQuantizer` into the loader and remove it as a separate class,
  to improve logging and the logic flow;
- Always add the instruction "Use `--xxx` to override" for any auto-detected
  option, to be more informative to end users;
- Fix wrong shape calculation when quantizing `nn.Embedding`;
- Fix wrong dtype calculation in group quantization when the input dtype
  is different from model dtype (e.g. "float32" in torch, but the model
  dtype in quantization is fp16 in `q4f16_1`);
- Fix inconsistent param names in layers such as `GroupQuantizeLinear`;
- Fix dtype inconsistency when a parameter is not quantized;
- Fix existing unittests.
junrushao authored Nov 4, 2023
1 parent 6ae02dd commit 4716704
Showing 27 changed files with 481 additions and 289 deletions.
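
As an illustration, here is a minimal sketch of how the new Python API might be invoked, mirroring the CLI flow shown in the convert_weight.py diff below. The model name "llama", the "q4f16_1" quantization, and all paths are placeholders, and the sketch assumes MODELS and QUANTIZATION are name-to-definition registries as used by the CLI:

from pathlib import Path

from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight
from mlc_chat.support.auto_target import detect_device

# Hypothetical local paths; adjust to your own model checkout.
config = Path("dist/models/llama-2-7b/config.json")

convert_weight(
    config=config,
    quantization=QUANTIZATION["q4f16_1"],  # predefined quantization scheme
    model=MODELS["llama"],                 # model architecture definition from the registry
    device=detect_device("auto"),          # detect a locally available device
    source=config.parent,                  # directory holding the original weights
    source_format="huggingface-torch",
    output=Path("dist/llama-2-7b-q4f16_1"),
)
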
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -16,6 +16,8 @@
# under the License.
[tool.isort]
profile = "black"
src_paths = ["python/mlc_chat"]
known_third_party = ["numpy", "tvm", "tqdm", "torch", "transformers"]

[tool.black]
line-length = 100
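
For context, with the "black" profile plus the src_paths and known_third_party entries above, isort is expected to group imports in python/mlc_chat roughly as sketched here (illustrative only, not taken from the repository):

# standard library
import logging
from pathlib import Path

# third-party, as declared in known_third_party
import numpy as np
import torch
import tvm

# first-party, resolved via src_paths
from mlc_chat.support import tqdm
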
17 changes: 11 additions & 6 deletions python/mlc_chat/cli/compile.py
@@ -60,14 +60,16 @@ def _parse_output(path: Union[str, Path]) -> Path:
default="auto",
choices=["auto"] + list(MODELS.keys()),
help="Model architecture, for example, llama. If not set, it is inferred "
"from the config.json file.",
"from the config.json file. "
"(default: %(default)s)",
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="The GPU device to compile the model to. If not set, it is inferred from locally "
"available GPUs.",
"available GPUs. "
"(default: %(default)s)",
)
parser.add_argument(
"--host",
@@ -81,25 +83,28 @@ def _parse_output(path: Union[str, Path]) -> Path:
"x86-64",
],
help="The host CPU ISA to compile the model to. If not set, it is inferred from the "
"local CPU.",
"local CPU. "
"(default: %(default)s)",
)
parser.add_argument(
"--opt",
type=OptimizationFlags.from_str,
default="",
default="O2",
help="Optimization flags. MLC LLM maintains a predefined set of optimization flags, "
"denoted as O0, O1, O2, O3, where O0 means no optimization, O2 means majority of them, "
"and O3 represents extreme optimization that could potentially break the system. "
"Meanwhile, optimization flags could be explicitly specified via details knobs, e.g. "
'--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0"',
'--opt="cutlass_attn=1;cutlass_norm=0;cublas_gemm=0;cudagraph=0". '
"(default: %(default)s)",
)
parser.add_argument(
"--prefix-symbols",
type=str,
default="",
help='Adding a prefix to all symbols exported. Similar to "objcopy --prefix-symbols". '
"This is useful when compiling multiple models into a single library to avoid symbol "
"conflicts. Differet from objcopy, this takes no effect for shared library.",
"conflicts. Differet from objcopy, this takes no effect for shared library. "
'(default: "")',
)
parser.add_argument(
"--output",
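
The "(default: %(default)s)" suffix added throughout these help strings relies on argparse's built-in placeholder substitution; a tiny standalone sketch (not part of the commit) of how it renders:

import argparse

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument(
    "--opt",
    default="O2",
    help="Optimization flags. (default: %(default)s)",  # argparse substitutes the default here
)
print(parser.format_help())
# ...
#   --opt OPT   Optimization flags. (default: O2)
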
78 changes: 24 additions & 54 deletions python/mlc_chat/cli/convert_weight.py
@@ -4,14 +4,10 @@
from pathlib import Path
from typing import Union

import tvm
from mlc_chat.compiler import MODELS, QUANTIZATION
from mlc_chat.compiler.parameter import HuggingFaceLoader
from mlc_chat.support import tqdm
from tvm.contrib import tvmjs
from mlc_chat.compiler import MODELS, QUANTIZATION, convert_weight

from ..support.auto_config import detect_config, detect_model_type
from ..support.auto_target import detect_target_and_host
from ..support.auto_target import detect_device
from ..support.auto_weight import detect_weight

logging.basicConfig(
@@ -57,17 +53,17 @@ def _parse_output(path: Union[str, Path]) -> Path:
parser.add_argument(
"--source",
type=str,
required=False,
default="auto",
help="The path to original model weight, infer from `config` if missing",
help="The path to original model weight, infer from `config` if missing. "
"(default: %(default)s)",
)
parser.add_argument(
"--source-format",
type=str,
required=False,
choices=["auto", "huggingface-torch", "huggingface-safetensor"],
default="auto",
help="The format of source model weight, infer from `config` if missing",
help="The format of source model weight, infer from `config` if missing. "
"(default: %(default)s)",
)
parser.add_argument(
"--quantization",
@@ -82,14 +78,16 @@ def _parse_output(path: Union[str, Path]) -> Path:
default="auto",
choices=["auto"] + list(MODELS.keys()),
help="Model architecture, for example, llama. If not set, it is inferred "
"from the config.json file.",
"from the config.json file. "
"(default: %(default)s)",
)
parser.add_argument(
"--device",
type=str,
default="auto",
help="The device used to do quantization, \
for example `auto` / `cuda:0` / `cuda --arch sm86`",
type=detect_device,
help="The device used to do quantization, for example, / `cuda:0`. "
"Detect from local environment if not specified. "
"(default: %(default)s)",
)
parser.add_argument(
"--output",
@@ -100,49 +98,21 @@ def _parse_output(path: Union[str, Path]) -> Path:
"will contain `params_shard_*.bin` and `ndarray-cache.json`.",
)

# parse arguments
parsed = parser.parse_args()
parsed.source = _parse_source(parsed.source, parsed.config)
parsed.params, parsed.source_format = detect_weight(
parsed.source, parsed.config, weight_format=parsed.source_format
parsed.source, parsed.source_format = detect_weight(
weight_path=_parse_source(parsed.source, parsed.config),
config_json_path=parsed.config,
weight_format=parsed.source_format,
)
model = detect_model_type(parsed.model_type, parsed.config)

# detect quantization target
quantization_target, _ = detect_target_and_host(parsed.device)
if parsed.device != "auto":
device = tvm.runtime.device(parsed.device.split(" ")[0])
else:
if quantization_target.kind.name == "cuda":
device = tvm.cuda(0)
else:
device = tvm.cpu(0)

# model config & quantization config
model_config = model.config.from_file(parsed.config)
quantization_config = QUANTIZATION[parsed.quantization]
_, quantize_map = model.quantize[quantization_config.kind](model_config, quantization_config)

# loader setup
if parsed.source_format in ("huggingface-torch", "huggingface-safetensor"):
loader = HuggingFaceLoader(
path=parsed.params,
extern_param_map=model.source[parsed.source_format](model_config, None),
quantize_param_map=quantize_map,
)
else:
raise ValueError(f"Unsupported loader source format: {parsed.source_format}")

# load and quantize
with quantization_target, tqdm.redirect():
param_dict = dict(loader.load(device=device))

# dump to output directory
tvmjs.dump_ndarray_cache(
param_dict,
f"{parsed.output}/params",
meta_data={"ParamSize": len(param_dict)},
encode_format="raw",
convert_weight(
config=parsed.config,
quantization=QUANTIZATION[parsed.quantization],
model=model,
device=parsed.device,
source=parsed.source,
source_format=parsed.source_format,
output=parsed.output,
)


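
Passing a callable such as detect_device as the argparse type= means the raw --device string is converted and validated during parsing; below is a generic sketch of that pattern using a hypothetical stand-in converter rather than the real detect_device:

import argparse


def parse_device(value: str) -> str:
    """Hypothetical stand-in for detect_device: validate the raw CLI string."""
    if not value.startswith(("auto", "cuda", "cpu")):
        raise argparse.ArgumentTypeError(f"unrecognized device: {value}")
    return value


parser = argparse.ArgumentParser()
parser.add_argument("--device", type=parse_device, default="auto")
args = parser.parse_args(["--device", "cuda:0"])
print(args.device)  # "cuda:0", already passed through parse_device
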
3 changes: 2 additions & 1 deletion python/mlc_chat/compiler/__init__.py
@@ -4,7 +4,8 @@
"""
from . import compiler_pass
from .compile import CompileArgs, compile # pylint: disable=redefined-builtin
from .convert_weight import convert_weight
from .flags_optimization import OptimizationFlags
from .loader import ExternMapping, HuggingFaceLoader, QuantizeMapping
from .model import MODEL_PRESETS, MODELS, Model
from .parameter import ExternMapping, HuggingFaceLoader, QuantizeMapping
from .quantization import QUANTIZATION
124 changes: 124 additions & 0 deletions python/mlc_chat/compiler/convert_weight.py
@@ -0,0 +1,124 @@
"""Python entrypoint of weight conversion."""
import dataclasses
import logging
import math
from io import StringIO
from pathlib import Path

import numpy as np
from tvm.contrib import tvmjs
from tvm.runtime import Device, NDArray
from tvm.runtime import cpu as cpu_device
from tvm.target import Target

from mlc_chat.support import tqdm

from ..support.style import bold, green
from .loader import LOADER
from .model import Model
from .quantization import Quantization

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class ConversionArgs: # pylint: disable=too-many-instance-attributes
"""Arguments to MLC LLM's weight conversation and quantization flow."""

config: Path
quantization: Quantization
model: Model
device: Device
source: Path
source_format: str
output: Path


def _echo_args(args: ConversionArgs) -> None:
def _device_to_str(device: Device) -> str:
return f"{Device.MASK2STR[device.device_type]}:{device.device_id}"

out = StringIO()
print(f"{bold('Weight conversion with arguments:')}", file=out)
print(f" {bold('--config'):<25} {args.config}", file=out)
print(f" {bold('--quantization'):<25} {args.quantization}", file=out)
print(f" {bold('--model-type'):<25} {args.model.name}", file=out)
print(f" {bold('--device'):<25} {_device_to_str(args.device)}", file=out)
print(f" {bold('--source'):<25} {args.source}", file=out)
print(f" {bold('--source-format'):<25} {args.source_format}", file=out)
print(f" {bold('--output'):<25} {args.output}", file=out)
print(out.getvalue().rstrip())


def _convert_args(args: ConversionArgs) -> None: # pylint: disable=too-many-locals
# model config & quantization config
model_config = args.model.config.from_file(args.config)
model, quantize_map = args.model.quantize[args.quantization.kind](
model_config, args.quantization
)
_, _named_params = model.export_tvm(spec=model.get_default_spec()) # type: ignore[attr-defined]
named_params = dict(_named_params)

def _check_param(name: str, param: NDArray):
nonlocal named_params
if name not in named_params:
raise ValueError(f"Parameter not found in model: {name}")
if name in param_dict:
raise ValueError(f"Duplication: Parameter {name} already computed")
expect_shape = tuple(int(x) for x in named_params[name].shape)
expect_dtype = named_params[name].dtype
actual_shape = tuple(int(x) for x in param.shape)
actual_dtype = param.dtype
if actual_shape != expect_shape:
raise ValueError(
f"Parameter {name} has shape {param.shape}, but expected {expect_shape}"
)
if actual_dtype != expect_dtype:
raise ValueError(
f"Parameter {name} has dtype {param.dtype}, but expected {expect_dtype}"
)
del named_params[name]

# load and quantize
param_dict = {}
total_bytes = 0.0
total_params = 0
with Target.from_device(args.device), tqdm.redirect():
for name, param in LOADER[args.source_format](
path=args.source,
extern_param_map=args.model.source[args.source_format](model_config, args.quantization),
quantize_param_map=quantize_map,
).load(device=args.device):
_check_param(name, param)
param = param.copyto(cpu_device())
param_dict[name] = param
total_bytes += math.prod(param.shape) * np.dtype(param.dtype).itemsize
total_params += math.prod(param.shape)
if named_params:
raise ValueError(f"Parameter not found in source: {', '.join(named_params.keys())}")
# dump to output directory
tvmjs.dump_ndarray_cache(
param_dict,
str(args.output),
meta_data={"ParamSize": len(param_dict)},
encode_format="raw",
)
logger.info("%s to %s", green("Saved"), bold(str(args.output)))
logger.info("%s: %.3f GB", green("Total parameter size"), total_bytes / (1024**3))
logger.info("%s: %d", green("Total number of parameter tensors"), len(param_dict))
logger.info(f"%s: {total_params:,}", green("Total number of parameters"))


def convert_weight( # pylint: disable=too-many-arguments
config: Path,
quantization: Quantization,
model: Model,
device: Device,
source: Path,
source_format: str,
output: Path,
):
"""MLC LLM's weight conversation and quantization flow."""
args = ConversionArgs(config, quantization, model, device, source, source_format, output)
_echo_args(args)
_convert_args(args)
python/mlc_chat/compiler/loader/__init__.py
@@ -3,4 +3,5 @@
parameters and parameters in MLC-defined models.
"""
from .huggingface_loader import HuggingFaceLoader
from .loader import LOADER, Loader
from .mapping import ExternMapping, QuantizeMapping