Merge pull request #2 from shadowpa0327/enhance/long_bench_eval
Minor bug fix in caching; add logging to the LongBench implementation
shadowpa0327 authored Jul 26, 2024
2 parents d17312e + 367e716 commit 7fae285
Showing 6 changed files with 73 additions and 28 deletions.
3 changes: 3 additions & 0 deletions compress.py
@@ -7,6 +7,9 @@
from tqdm import tqdm
from palu.decomposition import compress_model_whiten

import os
os.environ["HF_TOKEN"] = "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"

def compress(args):
# set seed
set_seed(args.seed)
1 change: 0 additions & 1 deletion longbench_utils/config/model2maxlen.json
@@ -5,7 +5,6 @@
"internlm-7b-8k": 7500,
"chatglm2-6b": 31500,
"chatglm2-6b-32k": 31500,
"vicuna-v1.5-7b-16k": 15500,
"LLaMA-2-7B-32K": 7500,
"llama-7b": 4096,
"Llama-2-7b-chat-hf": 4096,
6 changes: 4 additions & 2 deletions palu/decomposition.py
@@ -29,6 +29,7 @@ def get_whiten_scale_matrix(model, tokenizer, args, dev):
seqlen=2048
)
cache_file = f"cache/whiten/{model_id.replace('/','_')}_w2_scaling_matrices_fp16.pt"
os.makedirs("cache/whiten", exist_ok=True)
"""
cache format:
[
@@ -71,7 +72,8 @@ def get_whiten_scale_matrix(model, tokenizer, args, dev):
# Here, inference is performed in a layer-wise manner.
use_cache = model.config.use_cache
model.config.use_cache = False
if "llama" in model_id or "mistral" in model_id or "vicuna" in model_id:
#FIXME: This is not a good implementation...
if "llama" in model_id or "mistral" in model_id or "vicuna" in model_id or "longchat":
layers = model.model.layers
elif "opt" in model_id:
layers = model.model.decoder.layers
@@ -186,7 +188,7 @@ def hook(module, input, output):
model.config.use_cache = use_cache
if args.use_cache:
torch.save(scaling_matrices, cache_file)
logger.info(f"Save the whiten scale matrix dict to: {cache_file}", fg="yellow")
logger.info(f"Save the whiten scale matrix dict to: {cache_file}")

def compress_model_whiten(model, tokenizer, args, dev, selection_result):
# NOTE(brian1009): Prepare whiten scaling matrix
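
The caching fix above creates cache/whiten before torch.save() writes the whitening matrices. A minimal sketch of the load-or-compute pattern this enables (the helper name and the compute_fn callback are illustrative, not part of the repo):

import os
import torch

def load_or_compute_whiten_cache(model_id, compute_fn, use_cache=True):
    # Cache path mirrors the f-string used in get_whiten_scale_matrix.
    cache_file = f"cache/whiten/{model_id.replace('/', '_')}_w2_scaling_matrices_fp16.pt"
    if use_cache and os.path.exists(cache_file):
        # Reuse previously computed scaling matrices.
        return torch.load(cache_file)
    scaling_matrices = compute_fn()
    if use_cache:
        # The fix: create the directory first, otherwise torch.save raises FileNotFoundError.
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        torch.save(scaling_matrices, cache_file)
    return scaling_matrices
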
5 changes: 3 additions & 2 deletions requirements.txt
@@ -5,12 +5,13 @@ numpy==1.26.4
omegaconf==2.3.0
torch==2.1.2
tqdm==4.66.1
transformers==4.43.1
transformers==4.43.3
yacs==0.1.8
accelerate==0.32.1
protobuf==4.25.1
jieba==0.42.1
rouge
fuzzywuzzy
loguru==0.7.2
triton==3.0.0
fastchat==0.1.0
#triton==3.0.0
73 changes: 55 additions & 18 deletions run_long_bench.py
100644 → 100755
@@ -8,7 +8,9 @@
import random
import argparse
import time
import json
from loguru import logger
from datetime import datetime
os.environ["WANDB_DISABLED"] = "true"

from longbench_utils import scorer, MODEL2MAXLEN, DATASET2PROMPT, DATASET2MAXLEN
@@ -23,16 +25,38 @@ def post_process(response, model_name):
response = response.split("<eoa>")[0]
return response

# This is the customized prompt construction for chat models
def build_chat(tokenizer, prompt, model_name):
# Copied from KIVI
if "longchat" in model_name.lower() or "vicuna" in model_name.lower():
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
elif "mistral-v0.2-instruct" in model_name.lower():
messages = [
{
"role": "user",
"content": prompt
}
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return prompt

def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name):
preds = []
for json_obj in tqdm(data):
prompt = prompt_format.format(**json_obj)
# truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]

if len(tokenized_prompt) > max_length:
half = int(max_length/2)
prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)

if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without chat-style prompts on these tasks
prompt = build_chat(tokenizer, prompt, model_name)

input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
context_length = input.input_ids.shape[-1]
@@ -53,12 +77,14 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
num_beams=1,
do_sample=False,
temperature=1.0,
pad_token_id=tokenizer.eos_token_id,
)[0]
pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
pred = post_process(pred, model_name)
preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})
return preds


def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -72,26 +98,28 @@ def main(args):
model2maxlen = MODEL2MAXLEN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
model, tokenizer = load_model_and_tokenizer(args.model_name_or_path, use_flash_attn2=args.flash2)
configure_latent_quantizer(
model, n_bits=args.lt_bits,
group_size=args.lt_group_size,
sym=args.lt_sym,
clip_ratio=args.lt_clip_ratio,
hadamard=args.lt_hadamard
)
model_name = args.model_name_or_path.split("/")[-1]
orig_model_name = "Mistral-7B-v0.1" if "mistral" in model_name.lower() else "Llama-2-7b-chat-hf"
if "mistral" in model_name.lower():
orig_model_name = "Mistral-7B-v0.1"
elif "llama" in model_name.lower():
orig_model_name = "Llama-2-7b-chat-hf"
elif "vicuna" in model_name.lower():
orig_model_name = "vicuna-v1.5-7b-16k"

#NOTE(brian1009): This is a hack to get the model name.
# We assume the model name is the last part of the path, and that Palu's compression
# information follows the model name, separated by "_".
# Hence, we split the path by "/" and keep only the part before the first "_".
# Example: Mistral-7B-Instruct-v0.2_ratio-0.7_gs-4-fisher_uniform
raw_model_name = args.model_name_or_path.split("/")[-1]
model_type = args.model_name_or_path.split("/")[-1].split('_')[0]
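# e.g., for the (illustrative) path "ckpts/Mistral-7B-Instruct-v0.2_ratio-0.7_gs-4-fisher_uniform":
#   raw_model_name -> "Mistral-7B-Instruct-v0.2_ratio-0.7_gs-4-fisher_uniform"
#   model_type     -> "Mistral-7B-Instruct-v0.2"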

model.eval()
max_length = model2maxlen[orig_model_name]
logger.info(f"Running model: {model_name}")
if model_type not in model2maxlen:
raise ValueError(f"Model {model_type} not supported")

max_length = model2maxlen[model_type]
logger.info(f"Running model: {raw_model_name}")
logger.info(f"Max length: {max_length}")
datasets = args.datasets
# we design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
@@ -104,15 +132,17 @@ def main(args):
results = {}

for dataset in datasets:
logger.info("Evaluating dataset: {}".format(dataset))
start_time = time.time()
data = load_dataset('THUDM/LongBench', dataset, split='test')
prompt_format = dataset2prompt[dataset]
max_gen = dataset2maxlen[dataset]
preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name)
preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_type)
end_time = time.time()
elapsed_time = end_time - start_time
logger.info(f"Elapsed time for dataset {dataset}: {elapsed_time/60} minutes")


# calculate score
predictions, answers, lengths = [], [], []
for pred in preds:
predictions.append(pred["pred"])
@@ -123,11 +153,16 @@ def main(args):
score = scorer(dataset, predictions, answers, all_classes)
logger.info(f"dataset: {dataset}")
logger.info(f"score: {score}")
results[dataset] = {"score": score}

os.makedirs("results/Longbench", exist_ok=True)
with open(f"results/Longbench/{model_name}.txt", "w") as f:
f.write(str(results))
# Log the results for each dataset
with open(f"results/Longbench/{raw_model_name}_bits_{args.lt_bits}.json", "a") as f:
data_to_log = {
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"dataset": dataset,
"score": score,
}
json.dump(data_to_log, f)
f.write("\n")

if __name__ == '__main__':
seed_everything(42)
@@ -147,6 +182,8 @@ def main(args):

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO" if not args.verbose else "DEBUG")
# Create a directory for logging evaluation results.
os.makedirs("results/Longbench", exist_ok=True)

main(args)

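With the logging change above, each dataset's score is appended as one JSON object per line to results/Longbench/<raw_model_name>_bits_<lt_bits>.json. A hedged sketch of reading that log back (the file name below is illustrative, not from the commit):

import json

results_path = "results/Longbench/Llama-2-7b-chat-hf_ratio-0.7_bits_16.json"  # illustrative path
with open(results_path) as f:
    for line in f:
        record = json.loads(line)
        print(record["timestamp"], record["dataset"], record["score"])
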
13 changes: 8 additions & 5 deletions utils.py
@@ -4,7 +4,8 @@
import random, torch
from functools import reduce
from palu.model import HeadwiseLowRankModule
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
@@ -47,9 +48,9 @@ def get_module_by_name(module, module_name):
def dump_to_huggingface_repos(model, tokenizer, save_path, args):
tokenizer.save_pretrained(save_path)
#model.generation_config = Gene
#if "vicuna" in model.config._name_or_path.lower():
#if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
#NOTE(brian1009): Ad-hoc fixing the bug in Vicuna
#model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)
# model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)
model.save_pretrained(save_path)
config = model.config.to_dict()
config["head_wise_ranks"] = {}
@@ -72,7 +73,7 @@ def dump_to_huggingface_repos(model, tokenizer, save_path, args):
json.dump(config, open(save_path + "/config.json", "w"), indent=2)


def load_model_and_tokenizer(model_name_or_path):
def load_model_and_tokenizer(model_name_or_path, use_flash_attn2=False):
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
@@ -82,11 +83,12 @@ def load_model_and_tokenizer(model_name_or_path):
torch_dtype=torch.float16,
trust_remote_code=True,
device_map="auto",
attn_implementation="flash_attention_2" if use_flash_attn2 else "sdpa",
)
model.eval()
# Fix the bug in generation configs
#TODO: Add a reference to the issue where others hit the same bug
if "vicuna" in model.config._name_or_path.lower():
if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
model.generation_config.do_sample = True

return model, tokenizer
@@ -101,4 +103,5 @@ def add_common_args(parser: argparse.ArgumentParser):
parser.add_argument('--lt_sym', action='store_true', help='Symmetric quantization for low_rank latents')
parser.add_argument('--lt_clip_ratio', type=float, help='Clip ratio for low_rank latents', default=1.0)
parser.add_argument('--lt_hadamard', action='store_true', help='Apply Hadamard transform to low_rank latents')
parser.add_argument('--flash2', action='store_true', help='Whether to use FlashAttention-2')
return parser
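
For reference, a hedged usage sketch of the updated loader; the checkpoint path is illustrative, and flash_attention_2 additionally requires the flash-attn package to be installed:

from utils import load_model_and_tokenizer

model, tokenizer = load_model_and_tokenizer(
    "path/to/palu_compressed_model",  # hypothetical checkpoint directory
    use_flash_attn2=True,             # equivalent to passing --flash2 on the command line
)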
