diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md index 359ed4d5e1e8..1b7515498e44 100644 --- a/docs/source/en/gguf.md +++ b/docs/source/en/gguf.md @@ -46,16 +46,30 @@ The initial supported quantization types are decided according to the popular qu on the Hub. - F32 +- F16 +- BF16 +- Q4_0 +- Q4_1 +- Q5_0 +- Q5_1 +- Q8_0 - Q2_K - Q3_K -- Q4_0 - Q4_K - Q5_K - Q6_K -- Q8_0 +- IQ1_S +- IQ1_M +- IQ2_XXS +- IQ2_XS +- IQ2_S +- IQ3_XXS +- IQ3_S +- IQ4_XS +- IQ4_NL -We take example from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the -weights. +> [!NOTE] +> To support gguf dequantization, `gguf>=0.10.0` installation is required. ### Supported model architectures diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py index 7da09be841e1..fe5b71b7d613 100644 --- a/src/transformers/integrations/ggml.py +++ b/src/transformers/integrations/ggml.py @@ -33,44 +33,6 @@ logger = logging.get_logger(__name__) -# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md -GGML_TYPES = { - "F32": 0, - "F16": 1, - "Q4_0": 2, - "Q8_0": 8, - "Q2_K": 10, - "Q3_K": 11, - "Q4_K": 12, - "Q5_K": 13, - "Q6_K": 14, -} - -# The Blocksizes are reported in bytes -# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801 -GGML_BLOCK_SIZES = { - "Q8_0": 2 + 32, # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales - "Q4_K": 144, - # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales - "Q4_0": 2 + 16, - "Q6_K": 210, - # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9 - "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, - "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, - "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, -} - -# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md -DATA_TYPES = { - "uint32": 4, - "int32": 5, - "float32": 6, - "bool": 7, - "string": 8, - "array": 9, - "uint64": 10, -} - GGUF_TENSOR_MAPPING = { "llama": { "token_embd": "model.embed_tokens", @@ -217,303 +179,6 @@ def _gguf_parse_value(_value, data_type): return _value -def dequantize_q4_k(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116 - block_size = GGML_BLOCK_SIZES["Q4_K"] - num_blocks = n_bytes // block_size - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) - - # Casting to float32 because float16 is very slow on CPU - scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) - scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) - qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) - qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) - - # Dequantize scales and offsets (6 bits and 4 + 2 bits) - factors = scale_factors * np.concatenate( - [qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1 - ) - offsets = scale_offsets * np.concatenate( - [qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1 - ) - - # Interleave low and high quantized bits - qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) - # Dequantize final 
weights using scales and offsets - return factors * qs2 - offsets - - -def dequantize_q4_0(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11 - block_size = GGML_BLOCK_SIZES["Q4_0"] - num_blocks = n_bytes // block_size - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) - - # The scales are stored on the first 2 bytes and the rest corresponds to the quants - scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) - # scales = np.nan_to_num(scales) - # the rest of the bytes corresponds to the quants - we discard the first two bytes - quants = data_u8[:, 2:] - - ql = (quants[:, :] & 0xF).astype(np.int8) - 8 - qr = (quants[:, :] >> 4).astype(np.int8) - 8 - - # Use hstack - quants = np.hstack([ql, qr]) - - return (scales * quants).astype(np.float32) - - -def dequantize_q6_k(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152 - block_size = GGML_BLOCK_SIZES["Q6_K"] - num_blocks = n_bytes // block_size - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) - data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size) - - scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32) - - # TODO use uint8 and cast later? 
- ql = data_u8[:, :128].astype(np.int16) - qh = data_u8[:, 128:192].astype(np.int16) - sc = data_i8[:, 192:208, np.newaxis].astype(np.float32) - - # Unpack bits, subtraction requires signed data type - q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32 - q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32 - q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32 - q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32 - q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32 - q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32 - q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32 - q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32 - - # Dequantize - return scales * np.concatenate( - [ - sc[:, 0] * q1[:, :16], - sc[:, 1] * q1[:, 16:], - sc[:, 2] * q2[:, :16], - sc[:, 3] * q2[:, 16:], - sc[:, 4] * q3[:, :16], - sc[:, 5] * q3[:, 16:], - sc[:, 6] * q4[:, :16], - sc[:, 7] * q4[:, 16:], - sc[:, 8] * q5[:, :16], - sc[:, 9] * q5[:, 16:], - sc[:, 10] * q6[:, :16], - sc[:, 11] * q6[:, 16:], - sc[:, 12] * q7[:, :16], - sc[:, 13] * q7[:, 16:], - sc[:, 14] * q8[:, :16], - sc[:, 15] * q8[:, 16:], - ], - axis=1, - ) - - -def dequantize_q8_0(data, n_bytes: int): - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43 - block_size = GGML_BLOCK_SIZES["Q8_0"] - num_blocks = n_bytes // block_size - - scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) - qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] - - return scales * qs - - -def dequantize_q2_k(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74 - num_blocks = n_bytes // GGML_BLOCK_SIZES["Q2_K"] - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"]) - - dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) - d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32) - scales = data_u8[:, :16].reshape(num_blocks, 16, 1) - qs = data_u8[:, 16:80].reshape(num_blocks, 64) - - tmp = np.stack( - [ - qs[:, 00:16] >> 0, - qs[:, 16:32] >> 0, - qs[:, 00:16] >> 2, - qs[:, 16:32] >> 2, - qs[:, 00:16] >> 4, - qs[:, 16:32] >> 4, - qs[:, 00:16] >> 6, - qs[:, 16:32] >> 6, - qs[:, 32:48] >> 0, - qs[:, 48:64] >> 0, - qs[:, 32:48] >> 2, - qs[:, 48:64] >> 2, - qs[:, 32:48] >> 4, - qs[:, 48:64] >> 4, - qs[:, 32:48] >> 6, - qs[:, 48:64] >> 6, - ], - axis=1, - ) - - return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) - - -def dequantize_q3_k(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95 - num_blocks = n_bytes // GGML_BLOCK_SIZES["Q3_K"] - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"]) - - d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) - bits = 
np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") - bits = 4 ^ (bits << 2) - qs = data_u8[:, 32 : 32 + 64].astype(np.int16) - a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) - scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) - scales[:, 0] = (a & 15) | ((c & 3) << 4) - scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) - scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4) - scales[:, 3] = (b >> 4) | ((c >> 6) << 4) - scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) - - return ( - d - * (scales - 32) - * np.stack( - [ - (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), - (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), - (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), - (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), - (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), - (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), - (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), - (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), - (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), - (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), - (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), - (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), - (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), - (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), - (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), - (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]), - ], - axis=1, - ) - ) - - -def dequantize_q5_k(data, n_bytes: int): - # C implementation - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129 - # C struct definition - # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138 - num_blocks = n_bytes // GGML_BLOCK_SIZES["Q5_K"] - - data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2) - data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"]) - - d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) - dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) - scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) - qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1) - qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32) - - bits = np.unpackbits(qh, axis=-1, bitorder="little") - - qs_hi_4 = qs >> 4 - qs_lo_4 = qs & 15 - - scales_lo_6 = scales[:, :8] & 63 - scales_hi_6 = scales[:, :8] >> 6 - scales_lo_4 = scales[:, 8:] & 15 - scales_hi_4 = scales[:, 8:] >> 4 - - m1 = dmin * scales_lo_6[:, 4] - m2 = dmin * scales_lo_6[:, 5] - m3 = dmin * scales_lo_6[:, 6] - m4 = dmin * scales_lo_6[:, 7] - m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4)) - m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4)) - m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4)) - m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4)) - - d1 = d * scales_lo_6[:, 0] - d2 = d * scales_lo_6[:, 1] - d3 = d * scales_lo_6[:, 2] - d4 = d * scales_lo_6[:, 3] - d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4)) - d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4)) - d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) - d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) - - return np.concatenate( - [ - d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, - d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, - d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, - d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, - d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, - d6 * (qs_hi_4[:, 2] + 
(bits[:, :, 5] << 4)) - m6, - d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, - d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, - ], - axis=1, - ) - - -def load_dequant_gguf_tensor(shape, ggml_type, data, n_bytes): - if ggml_type == GGML_TYPES["F32"]: - values = data - elif ggml_type == GGML_TYPES["F16"]: - values = data - elif ggml_type == GGML_TYPES["Q8_0"]: - values = dequantize_q8_0(data, n_bytes) - elif ggml_type == GGML_TYPES["Q4_0"]: - values = dequantize_q4_0(data, n_bytes) - elif ggml_type == GGML_TYPES["Q4_K"]: - values = dequantize_q4_k(data, n_bytes) - elif ggml_type == GGML_TYPES["Q6_K"]: - values = dequantize_q6_k(data, n_bytes) - elif ggml_type == GGML_TYPES["Q2_K"]: - values = dequantize_q2_k(data, n_bytes) - elif ggml_type == GGML_TYPES["Q3_K"]: - values = dequantize_q3_k(data, n_bytes) - elif ggml_type == GGML_TYPES["Q5_K"]: - values = dequantize_q5_k(data, n_bytes) - else: - raise NotImplementedError( - f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose" - ) - - return values.reshape(shape[::-1]) - - class GGUFTokenizerSkeleton: def __init__(self, dict_): for k, v in dict_.items(): diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py index 52b1068e003f..e5fa0ff7b509 100644 --- a/src/transformers/modeling_gguf_pytorch_utils.py +++ b/src/transformers/modeling_gguf_pytorch_utils.py @@ -24,9 +24,9 @@ GGUF_TENSOR_MAPPING, GGUF_TOKENIZER_MAPPING, _gguf_parse_value, - load_dequant_gguf_tensor, ) from .utils import is_torch_available +from .utils.import_utils import is_gguf_available from .utils.logging import get_logger @@ -71,14 +71,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): Whether to read the tensors from the file and return them. Not doing so is faster and only loads the metadata in memory. """ - try: - from gguf import GGUFReader - except (ImportError, ModuleNotFoundError): + if is_gguf_available() and is_torch_available(): + from gguf import GGUFReader, dequantize + else: logger.error( - "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see " + "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see " "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions." ) - raise + raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.") reader = GGUFReader(gguf_checkpoint_path) fields = reader.fields @@ -154,12 +154,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False): tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping] ) - shape = tensor.shape name = tensor.name - weights = load_dequant_gguf_tensor( - shape=shape, ggml_type=tensor.tensor_type, data=tensor.data, n_bytes=int(tensor.n_bytes) - ) + weights = dequantize(tensor.data, tensor.tensor_type) if architecture == "llama" and (".attn_k." in name or ".attn_q." 
in name): num_heads = parsed_parameters["config"]["num_attention_heads"] diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 3d30c9ff6479..b6dfa85b1d20 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -53,6 +53,7 @@ from .integrations.deepspeed import is_deepspeed_available from .utils import ( ACCELERATE_MIN_VERSION, + GGUF_MIN_VERSION, is_accelerate_available, is_apex_available, is_aqlm_available, @@ -407,11 +408,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): )(test_case) -def require_gguf(test_case): +def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): """ Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. """ - return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case) + return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")( + test_case + ) def require_fsdp(test_case, min_version: str = "1.12.0"): diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 56f594da15f1..6a2ef8437ccf 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -99,6 +99,7 @@ ACCELERATE_MIN_VERSION, ENV_VARS_TRUE_AND_AUTO_VALUES, ENV_VARS_TRUE_VALUES, + GGUF_MIN_VERSION, TORCH_FX_REQUIRED_VERSION, USE_JAX, USE_TF, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 0c16cac0f071..416599203d9c 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ ACCELERATE_MIN_VERSION = "0.26.0" FSDP_MIN_VERSION = "1.12.0" +GGUF_MIN_VERSION = "0.10.0" XLA_FSDPV2_MIN_VERSION = "2.2.0" @@ -155,7 +156,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _scipy_available = _is_package_available("scipy") _sentencepiece_available = _is_package_available("sentencepiece") _is_seqio_available = _is_package_available("seqio") -_is_gguf_available = _is_package_available("gguf") +_is_gguf_available, _gguf_version = _is_package_available("gguf", return_version=True) _sklearn_available = importlib.util.find_spec("sklearn") is not None if _sklearn_available: try: @@ -913,8 +914,8 @@ def is_seqio_available(): return _is_seqio_available -def is_gguf_available(): - return _is_gguf_available +def is_gguf_available(min_version: str = GGUF_MIN_VERSION): + return _is_gguf_available and version.parse(_gguf_version) >= version.parse(min_version) def is_protobuf_available(): diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py index e42900a1d51b..c81df1910eb6 100644 --- a/tests/quantization/ggml/test_ggml.py +++ b/tests/quantization/ggml/test_ggml.py @@ -30,18 +30,32 @@ class GgufIntegrationTests(unittest.TestCase): original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" + imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF" mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF" llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF" tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF" + # standard quants q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" - q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" + q5_0_gguf_model_id = 
"tinyllama-1.1b-chat-v1.0.Q5_0.gguf" + q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf" + # k-quants q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf" q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf" + q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf" q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf" q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf" - q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf" + # imatrix + iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf" + iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf" + iq2_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_S.gguf" + iq2_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XS.gguf" + iq2_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XXS.gguf" + iq3_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_S.gguf" + iq3_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_XXS.gguf" + iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf" + iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf" q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf" q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf" @@ -87,6 +101,16 @@ def test_q3_k(self): EXPECTED_TEXT = "Hello, World!\n\n```\n<|user" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_q5_0(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id).to(torch_device) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\n5. Use a library" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_q5_k(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id) model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device) @@ -151,6 +175,114 @@ def test_q8_0(self): EXPECTED_TEXT = "Hello, World!\n\n5. 
Use a library" self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_iq1_s(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, I'm a friend of mine, I" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq1_m(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, I am interested in purching a copy of" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq2_s(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello World!\n\n```\n<|user|" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq2_xs(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello World!\n\n```\n<|user|" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq2_xxs(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, I'm a software engineer. I'" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq3_s(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, World!\n\n5. 
Python:\n" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq3_xxs(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, I am interested in your product. Can you" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq4_xs(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + + def test_iq4_nl(self): + tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id) + model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id).to( + torch_device + ) + + text = tokenizer(self.example_text, return_tensors="pt").to(torch_device) + out = model.generate(**text, max_new_tokens=10) + + EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop" + self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT) + def test_f16(self): tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id) model = AutoModelForCausalLM.from_pretrained(