From a5025a9b8083e126754112521cad75140a673d8e Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 3 Sep 2024 18:58:14 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20Support=20dequantization=20for?=
 =?UTF-8?q?=20most=20GGML=20types=20(#32625)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* use gguf internal dequantize

* add Q5_0 test

* add iq1 test

* add remained test

* remove duplicated test

* update docs

* add gguf version limit

* make style

* update gguf import catch

* revert vocab_size patch

* make style

* use GGUF_MIN_VERSION everywhere
---
 docs/source/en/gguf.md                        |  22 +-
 src/transformers/integrations/ggml.py         | 335 ------------------
 .../modeling_gguf_pytorch_utils.py            |  17 +-
 src/transformers/testing_utils.py             |   7 +-
 src/transformers/utils/__init__.py            |   1 +
 src/transformers/utils/import_utils.py        |   7 +-
 tests/quantization/ggml/test_ggml.py          | 136 ++++++-
 7 files changed, 169 insertions(+), 356 deletions(-)

diff --git a/docs/source/en/gguf.md b/docs/source/en/gguf.md
index 359ed4d5e1e8..1b7515498e44 100644
--- a/docs/source/en/gguf.md
+++ b/docs/source/en/gguf.md
@@ -46,16 +46,30 @@ The initial supported quantization types are decided according to the popular qu
 on the Hub.
 
 - F32
+- F16
+- BF16
+- Q4_0
+- Q4_1
+- Q5_0
+- Q5_1
+- Q8_0
 - Q2_K
 - Q3_K
-- Q4_0
 - Q4_K
 - Q5_K
 - Q6_K
-- Q8_0
+- IQ1_S
+- IQ1_M
+- IQ2_XXS
+- IQ2_XS
+- IQ2_S
+- IQ3_XXS
+- IQ3_S
+- IQ4_XS
+- IQ4_NL
 
-We take example from the excellent [99991/pygguf](https://github.com/99991/pygguf) Python parser to dequantize the 
-weights.
+> [!NOTE]
+> To support gguf dequantization, `gguf>=0.10.0` installation is required.
 
 ### Supported model architectures
 
diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 7da09be841e1..fe5b71b7d613 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -33,44 +33,6 @@
 logger = logging.get_logger(__name__)
 
 
-# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
-GGML_TYPES = {
-    "F32": 0,
-    "F16": 1,
-    "Q4_0": 2,
-    "Q8_0": 8,
-    "Q2_K": 10,
-    "Q3_K": 11,
-    "Q4_K": 12,
-    "Q5_K": 13,
-    "Q6_K": 14,
-}
-
-# The Blocksizes are reported in bytes
-# Check out: https://github.com/ggerganov/llama.cpp/blob/8a56075b07a8b571bf95a912ffdce4c928c2b414/gguf-py/gguf/constants.py#L801
-GGML_BLOCK_SIZES = {
-    "Q8_0": 2 + 32,  # Q8_0 uses a blocksize of 32 (int8 tensors) + 2 bytes allocated for the scales
-    "Q4_K": 144,
-    # Q4_0 uses a blocksize of 32 but the 4-bit tensors are packed into 8-bit tensors + 2 bytes for the scales
-    "Q4_0": 2 + 16,
-    "Q6_K": 210,
-    # See: https://github.com/99991/pygguf/commit/a417edbfc029a1bc270f984a694f9128c5afa8b9
-    "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
-    "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
-    "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
-}
-
-# Listed here: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
-DATA_TYPES = {
-    "uint32": 4,
-    "int32": 5,
-    "float32": 6,
-    "bool": 7,
-    "string": 8,
-    "array": 9,
-    "uint64": 10,
-}
-
 GGUF_TENSOR_MAPPING = {
     "llama": {
         "token_embd": "model.embed_tokens",
@@ -217,303 +179,6 @@ def _gguf_parse_value(_value, data_type):
     return _value
 
 
-def dequantize_q4_k(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
-    block_size = GGML_BLOCK_SIZES["Q4_K"]
-    num_blocks = n_bytes // block_size
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
-
-    # Casting to float32 because float16 is very slow on CPU
-    scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
-    scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
-    qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
-    qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
-
-    # Dequantize scales and offsets (6 bits and 4 + 2 bits)
-    factors = scale_factors * np.concatenate(
-        [qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1
-    )
-    offsets = scale_offsets * np.concatenate(
-        [qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1
-    )
-
-    # Interleave low and high quantized bits
-    qs2 = np.stack([qs2 & 0xF, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
-    # Dequantize final weights using scales and offsets
-    return factors * qs2 - offsets
-
-
-def dequantize_q4_0(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1086
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L11
-    block_size = GGML_BLOCK_SIZES["Q4_0"]
-    num_blocks = n_bytes // block_size
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
-
-    # The scales are stored on the first 2 bytes and the rest corresponds to the quants
-    scales = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
-    # scales = np.nan_to_num(scales)
-    # the rest of the bytes corresponds to the quants - we discard the first two bytes
-    quants = data_u8[:, 2:]
-
-    ql = (quants[:, :] & 0xF).astype(np.int8) - 8
-    qr = (quants[:, :] >> 4).astype(np.int8) - 8
-
-    # Use hstack
-    quants = np.hstack([ql, qr])
-
-    return (scales * quants).astype(np.float32)
-
-
-def dequantize_q6_k(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
-    block_size = GGML_BLOCK_SIZES["Q6_K"]
-    num_blocks = n_bytes // block_size
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
-    data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
-
-    scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
-
-    # TODO use uint8 and cast later?
-    ql = data_u8[:, :128].astype(np.int16)
-    qh = data_u8[:, 128:192].astype(np.int16)
-    sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
-
-    # Unpack bits, subtraction requires signed data type
-    q1 = (ql[:, :32] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
-    q2 = (ql[:, 32:64] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
-    q3 = (ql[:, :32] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
-    q4 = (ql[:, 32:64] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
-    q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
-    q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
-    q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
-    q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
-
-    # Dequantize
-    return scales * np.concatenate(
-        [
-            sc[:, 0] * q1[:, :16],
-            sc[:, 1] * q1[:, 16:],
-            sc[:, 2] * q2[:, :16],
-            sc[:, 3] * q2[:, 16:],
-            sc[:, 4] * q3[:, :16],
-            sc[:, 5] * q3[:, 16:],
-            sc[:, 6] * q4[:, :16],
-            sc[:, 7] * q4[:, 16:],
-            sc[:, 8] * q5[:, :16],
-            sc[:, 9] * q5[:, 16:],
-            sc[:, 10] * q6[:, :16],
-            sc[:, 11] * q6[:, 16:],
-            sc[:, 12] * q7[:, :16],
-            sc[:, 13] * q7[:, 16:],
-            sc[:, 14] * q8[:, :16],
-            sc[:, 15] * q8[:, 16:],
-        ],
-        axis=1,
-    )
-
-
-def dequantize_q8_0(data, n_bytes: int):
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
-    block_size = GGML_BLOCK_SIZES["Q8_0"]
-    num_blocks = n_bytes // block_size
-
-    scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
-    qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
-
-    return scales * qs
-
-
-def dequantize_q2_k(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
-    num_blocks = n_bytes // GGML_BLOCK_SIZES["Q2_K"]
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"] // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q2_K"])
-
-    dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
-    d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
-    scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
-    qs = data_u8[:, 16:80].reshape(num_blocks, 64)
-
-    tmp = np.stack(
-        [
-            qs[:, 00:16] >> 0,
-            qs[:, 16:32] >> 0,
-            qs[:, 00:16] >> 2,
-            qs[:, 16:32] >> 2,
-            qs[:, 00:16] >> 4,
-            qs[:, 16:32] >> 4,
-            qs[:, 00:16] >> 6,
-            qs[:, 16:32] >> 6,
-            qs[:, 32:48] >> 0,
-            qs[:, 48:64] >> 0,
-            qs[:, 32:48] >> 2,
-            qs[:, 48:64] >> 2,
-            qs[:, 32:48] >> 4,
-            qs[:, 48:64] >> 4,
-            qs[:, 32:48] >> 6,
-            qs[:, 48:64] >> 6,
-        ],
-        axis=1,
-    )
-
-    return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
-
-
-def dequantize_q3_k(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
-    num_blocks = n_bytes // GGML_BLOCK_SIZES["Q3_K"]
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"] // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q3_K"])
-
-    d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
-    bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
-    bits = 4 ^ (bits << 2)
-    qs = data_u8[:, 32 : 32 + 64].astype(np.int16)
-    a, b, c = data_u8[:, 96 : 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
-    scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
-    scales[:, 0] = (a & 15) | ((c & 3) << 4)
-    scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
-    scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
-    scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
-    scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
-
-    return (
-        d
-        * (scales - 32)
-        * np.stack(
-            [
-                (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
-                (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
-                (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
-                (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
-                (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
-                (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
-                (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
-                (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
-                (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
-                (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
-                (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
-                (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
-                (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
-                (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
-                (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
-                (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]),
-            ],
-            axis=1,
-        )
-    )
-
-
-def dequantize_q5_k(data, n_bytes: int):
-    # C implementation
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
-    # C struct definition
-    # https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
-    num_blocks = n_bytes // GGML_BLOCK_SIZES["Q5_K"]
-
-    data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"] // 2)
-    data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, GGML_BLOCK_SIZES["Q5_K"])
-
-    d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
-    dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
-    scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
-    qh = data_u8[:, 16 : 16 + 32].reshape(num_blocks, 32, 1)
-    qs = data_u8[:, 48 : 48 + 128].reshape(num_blocks, 4, 32)
-
-    bits = np.unpackbits(qh, axis=-1, bitorder="little")
-
-    qs_hi_4 = qs >> 4
-    qs_lo_4 = qs & 15
-
-    scales_lo_6 = scales[:, :8] & 63
-    scales_hi_6 = scales[:, :8] >> 6
-    scales_lo_4 = scales[:, 8:] & 15
-    scales_hi_4 = scales[:, 8:] >> 4
-
-    m1 = dmin * scales_lo_6[:, 4]
-    m2 = dmin * scales_lo_6[:, 5]
-    m3 = dmin * scales_lo_6[:, 6]
-    m4 = dmin * scales_lo_6[:, 7]
-    m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
-    m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
-    m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
-    m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
-
-    d1 = d * scales_lo_6[:, 0]
-    d2 = d * scales_lo_6[:, 1]
-    d3 = d * scales_lo_6[:, 2]
-    d4 = d * scales_lo_6[:, 3]
-    d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
-    d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
-    d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
-    d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
-
-    return np.concatenate(
-        [
-            d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
-            d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
-            d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
-            d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
-            d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
-            d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
-            d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
-            d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
-        ],
-        axis=1,
-    )
-
-
-def load_dequant_gguf_tensor(shape, ggml_type, data, n_bytes):
-    if ggml_type == GGML_TYPES["F32"]:
-        values = data
-    elif ggml_type == GGML_TYPES["F16"]:
-        values = data
-    elif ggml_type == GGML_TYPES["Q8_0"]:
-        values = dequantize_q8_0(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q4_0"]:
-        values = dequantize_q4_0(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q4_K"]:
-        values = dequantize_q4_k(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q6_K"]:
-        values = dequantize_q6_k(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q2_K"]:
-        values = dequantize_q2_k(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q3_K"]:
-        values = dequantize_q3_k(data, n_bytes)
-    elif ggml_type == GGML_TYPES["Q5_K"]:
-        values = dequantize_q5_k(data, n_bytes)
-    else:
-        raise NotImplementedError(
-            f"ggml_type {ggml_type} not implemented - please raise an issue on huggingface transformers: https://github.com/huggingface/transformers/issues/new/choose"
-        )
-
-    return values.reshape(shape[::-1])
-
-
 class GGUFTokenizerSkeleton:
     def __init__(self, dict_):
         for k, v in dict_.items():
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 52b1068e003f..e5fa0ff7b509 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -24,9 +24,9 @@
     GGUF_TENSOR_MAPPING,
     GGUF_TOKENIZER_MAPPING,
     _gguf_parse_value,
-    load_dequant_gguf_tensor,
 )
 from .utils import is_torch_available
+from .utils.import_utils import is_gguf_available
 from .utils.logging import get_logger
 
 
@@ -71,14 +71,14 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
             Whether to read the tensors from the file and return them. Not doing so is faster
             and only loads the metadata in memory.
     """
-    try:
-        from gguf import GGUFReader
-    except (ImportError, ModuleNotFoundError):
+    if is_gguf_available() and is_torch_available():
+        from gguf import GGUFReader, dequantize
+    else:
         logger.error(
-            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF to be installed. Please see "
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
             "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
         )
-        raise
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
 
     reader = GGUFReader(gguf_checkpoint_path)
     fields = reader.fields
@@ -154,12 +154,9 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
                         tensor_name_mapping, GGUF_TO_TRANSFORMERS_MAPPING["tensors"][tensor_name_mapping]
                     )
 
-            shape = tensor.shape
             name = tensor.name
 
-            weights = load_dequant_gguf_tensor(
-                shape=shape, ggml_type=tensor.tensor_type, data=tensor.data, n_bytes=int(tensor.n_bytes)
-            )
+            weights = dequantize(tensor.data, tensor.tensor_type)
 
             if architecture == "llama" and (".attn_k." in name or ".attn_q." in name):
                 num_heads = parsed_parameters["config"]["num_attention_heads"]
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 3d30c9ff6479..b6dfa85b1d20 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -53,6 +53,7 @@
 from .integrations.deepspeed import is_deepspeed_available
 from .utils import (
     ACCELERATE_MIN_VERSION,
+    GGUF_MIN_VERSION,
     is_accelerate_available,
     is_apex_available,
     is_aqlm_available,
@@ -407,11 +408,13 @@ def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
     )(test_case)
 
 
-def require_gguf(test_case):
+def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
     """
     Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed.
     """
-    return unittest.skipUnless(is_gguf_available(), "test requires gguf")(test_case)
+    return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
+        test_case
+    )
 
 
 def require_fsdp(test_case, min_version: str = "1.12.0"):
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index b1a1bb56cbd8..27d102a9fd12 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -99,6 +99,7 @@
     ACCELERATE_MIN_VERSION,
     ENV_VARS_TRUE_AND_AUTO_VALUES,
     ENV_VARS_TRUE_VALUES,
+    GGUF_MIN_VERSION,
     TORCH_FX_REQUIRED_VERSION,
     USE_JAX,
     USE_TF,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index c4bb1a64eb63..8ae133d0ffe0 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -89,6 +89,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 
 ACCELERATE_MIN_VERSION = "0.26.0"
 FSDP_MIN_VERSION = "1.12.0"
+GGUF_MIN_VERSION = "0.10.0"
 XLA_FSDPV2_MIN_VERSION = "2.2.0"
 
 
@@ -156,7 +157,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _scipy_available = _is_package_available("scipy")
 _sentencepiece_available = _is_package_available("sentencepiece")
 _is_seqio_available = _is_package_available("seqio")
-_is_gguf_available = _is_package_available("gguf")
+_is_gguf_available, _gguf_version = _is_package_available("gguf", return_version=True)
 _sklearn_available = importlib.util.find_spec("sklearn") is not None
 if _sklearn_available:
     try:
@@ -914,8 +915,8 @@ def is_seqio_available():
     return _is_seqio_available
 
 
-def is_gguf_available():
-    return _is_gguf_available
+def is_gguf_available(min_version: str = GGUF_MIN_VERSION):
+    return _is_gguf_available and version.parse(_gguf_version) >= version.parse(min_version)
 
 
 def is_protobuf_available():
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index e42900a1d51b..c81df1910eb6 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -30,18 +30,32 @@
 class GgufIntegrationTests(unittest.TestCase):
     original_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     model_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
+    imatrix_model_id = "duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF"
     mistral_model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
     qwen2_model_id = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
     llama3_model_id = "NousResearch/Meta-Llama-3-8B-GGUF"
     tinyllama_model_id = "PenutChen/TinyLlama-1.1B-Chat-v1.0-GGUF"
 
+    # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
-    q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
+    q5_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_0.gguf"
+    q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
+    # k-quants
     q2_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
     q3_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q3_K_L.gguf"
+    q4_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf"
     q5_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf"
     q6_k_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q6_K.gguf"
-    q8_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
+    # imatrix
+    iq1_m_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_M.gguf"
+    iq1_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ1_S.gguf"
+    iq2_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_S.gguf"
+    iq2_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XS.gguf"
+    iq2_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ2_XXS.gguf"
+    iq3_s_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_S.gguf"
+    iq3_xxs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ3_XXS.gguf"
+    iq4_xs_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf"
+    iq4_nl_gguf_model_id = "TinyLlama-1.1B-Chat-v1.0-IQ4_NL.gguf"
 
     q4_0_mistral_model_id = "mistral-7b-instruct-v0.2.Q4_0.gguf"
     q4_0_qwen2_model_id = "qwen1_5-0_5b-chat-q4_0.gguf"
@@ -87,6 +101,16 @@ def test_q3_k(self):
         EXPECTED_TEXT = "Hello, World!\n\n```\n<|user"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_q5_0(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_0_gguf_model_id).to(torch_device)
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_q5_k(self):
         tokenizer = AutoTokenizer.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id)
         model = AutoModelForCausalLM.from_pretrained(self.model_id, gguf_file=self.q5_k_gguf_model_id).to(torch_device)
@@ -151,6 +175,114 @@ def test_q8_0(self):
         EXPECTED_TEXT = "Hello, World!\n\n5. Use a library"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
+    def test_iq1_s(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_s_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I'm a friend of mine, I"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq1_m(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq1_m_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I am interested in purching a copy of"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq2_s(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_s_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq2_xs(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xs_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello World!\n\n```\n<|user|"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq2_xxs(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq2_xxs_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I'm a software engineer. I'"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq3_s(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_s_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, World!\n\n5. Python:\n"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq3_xxs(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq3_xxs_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, I am interested in your product. Can you"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq4_xs(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_xs_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
+    def test_iq4_nl(self):
+        tokenizer = AutoTokenizer.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id)
+        model = AutoModelForCausalLM.from_pretrained(self.imatrix_model_id, gguf_file=self.iq4_nl_gguf_model_id).to(
+            torch_device
+        )
+
+        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
+        out = model.generate(**text, max_new_tokens=10)
+
+        EXPECTED_TEXT = "Hello, world!\n\n5. Using a loop"
+        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
+
     def test_f16(self):
         tokenizer = AutoTokenizer.from_pretrained(self.tinyllama_model_id, gguf_file=self.f16_tinyllama_model_id)
         model = AutoModelForCausalLM.from_pretrained(