Skip to content

Commit

Permalink
fix ftype
Browse files Browse the repository at this point in the history
  • Loading branch information
ngxson committed Jul 8, 2024
1 parent 84288ff commit 7a83f20
Showing 1 changed file with 10 additions and 21 deletions.
31 changes: 10 additions & 21 deletions convert_lora_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,12 @@

import logging
import argparse
import contextlib
import json
import os
import re
import sys
import types
from enum import IntEnum
from pathlib import Path
from hashlib import sha256
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Iterable, Iterator

import math
import numpy as np
import torch

if TYPE_CHECKING:
Expand All @@ -32,22 +25,17 @@

logger = logging.getLogger("lora-to-gguf")


def parse_args() -> argparse.Namespace:
all_models = ", ".join([arch for arch in Model._model_classes.keys()])
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file")
description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
parser.add_argument(
"--outfile", type=Path,
help="path to write to; default: based on input.",
)
parser.add_argument(
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--arch", type=str,
help=f"Arch of the base model, must be one of: {all_models} (default: LlamaForCausalLM)",
default="LlamaForCausalLM"
"--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
)
parser.add_argument(
"--bigendian", action="store_true",
Expand All @@ -73,14 +61,13 @@ def parse_args() -> argparse.Namespace:
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

# FIXME: outtype is not working
ftype_map: dict[str, gguf.LlamaFileType] = {
"f32": gguf.LlamaFileType.ALL_F32,
"f16": gguf.LlamaFileType.MOSTLY_F16,
"bf16": gguf.LlamaFileType.MOSTLY_BF16,
"q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
"auto": gguf.LlamaFileType.GUESSED,
}
ftype = ftype_map[args.outtype]

dir_base_model = args.base
dir_lora = args.lora_path
Expand Down Expand Up @@ -110,7 +97,7 @@ def parse_args() -> argparse.Namespace:
logger.error(f"Model {hparams['architectures'][0]} is not supported")
sys.exit(1)

model_instance = model_class(dir_base_model, ftype_map[args.outtype], fname_out, args.bigendian, False, False, None)
model_instance = model_class(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
logger.info("Set model parameters")
model_instance.set_gguf_parameters()

Expand Down Expand Up @@ -140,16 +127,18 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
# overwrite method
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
# TODO: This will not take into account tensor transformations
return [(name, data_torch)]

# overwrite method
def extra_f16_tensors(self, name: str, new_name: str, bid: int | None, n_dims: int) -> bool:
del name, new_name, bid, n_dims # unused
return True
return ftype != gguf.LlamaFileType.ALL_F32

model_instance.get_tensors = types.MethodType(get_tensors, model_instance)
model_instance.modify_tensors = types.MethodType(modify_tensors, model_instance)
model_instance.extra_f16_tensors = types.MethodType(extra_f16_tensors, model_instance)

model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
logger.info("Exporting model...")
model_instance.write()
Expand Down

0 comments on commit 7a83f20

Please sign in to comment.