minor bug fix in caching, and add logging in Longbench implementation #2

Merged: 1 commit merged on Jul 26, 2024
3 changes: 3 additions & 0 deletions compress.py
@@ -7,6 +7,9 @@
from tqdm import tqdm
from palu.decomposition import compress_model_whiten

import os
os.environ["HF_TOKEN"] = "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"

def compress(args):
# set seed
set_seed(args.seed)
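The two lines added to compress.py set HF_TOKEN directly in the source. A minimal sketch of an alternative setup, assuming the token is exported in the shell beforehand (this sketch is not part of the PR):

```python
import os

# Assumes `export HF_TOKEN=...` was run before launching compress.py (hypothetical setup).
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    raise RuntimeError("HF_TOKEN is not set; export it before running compress.py")
```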
1 change: 0 additions & 1 deletion longbench_utils/config/model2maxlen.json
@@ -5,7 +5,6 @@
"internlm-7b-8k": 7500,
"chatglm2-6b": 31500,
"chatglm2-6b-32k": 31500,
"vicuna-v1.5-7b-16k": 15500,
"LLaMA-2-7B-32K": 7500,
"llama-7b": 4096,
"Llama-2-7b-chat-hf": 4096,
6 changes: 4 additions & 2 deletions palu/decomposition.py
@@ -29,6 +29,7 @@ def get_whiten_scale_matrix(model, tokenizer, args, dev):
seqlen=2048
)
cache_file = f"cache/whiten/{model_id.replace('/','_')}_w2_scaling_matrices_fp16.pt"
os.makedirs("cache/whiten", exist_ok=True)
"""
cache format:
[
@@ -71,7 +72,8 @@ def get_whiten_scale_matrix(model, tokenizer, args, dev):
# Here, inference is performed in a layer-wise manner.
use_cache = model.config.use_cache
model.config.use_cache = False
if "llama" in model_id or "mistral" in model_id or "vicuna" in model_id:
#FIXME: This is not a good implementation...
if "llama" in model_id or "mistral" in model_id or "vicuna" in model_id or "longchat":
layers = model.model.layers
elif "opt" in model_id:
layers = model.model.decoder.layers
@@ -186,7 +188,7 @@ def hook(module, input, output):
model.config.use_cache = use_cache
if args.use_cache:
torch.save(scaling_matrices, cache_file)
logger.info(f"Save the whiten scale matrix dict to: {cache_file}", fg="yellow")
logger.info(f"Save the whiten scale matrix dict to: {cache_file}")

def compress_model_whiten(model, tokenizer, args, dev, selection_result):
# NOTE(brian1009): Prepare whiten scaling matrix
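The caching fix above creates cache/whiten before torch.save writes the whitening matrices. A minimal sketch of the load-or-compute pattern this enables, assuming a compute_fn that builds the matrices (illustrative only, not the exact Palu code):

```python
import os
import torch

def load_or_compute_scaling_matrices(cache_file, compute_fn, use_cache=True):
    # Reuse a previously cached result when available.
    if use_cache and os.path.exists(cache_file):
        return torch.load(cache_file)
    scaling_matrices = compute_fn()
    if use_cache:
        # Create the cache directory before saving, mirroring the fix above.
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        torch.save(scaling_matrices, cache_file)
    return scaling_matrices
```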
5 changes: 3 additions & 2 deletions requirements.txt
@@ -5,12 +5,13 @@ numpy==1.26.4
omegaconf==2.3.0
torch==2.1.2
tqdm==4.66.1
transformers==4.43.1
transformers==4.43.3
yacs==0.1.8
accelerate==0.32.1
protobuf==4.25.1
jieba==0.42.1
rouge
fuzzywuzzy
loguru==0.7.2
triton==3.0.0
fastchat==0.1.0
#triton==3.0.0
73 changes: 55 additions & 18 deletions run_long_bench.py
100644 → 100755
@@ -8,7 +8,9 @@
import random
import argparse
import time
import json
from loguru import logger
from datetime import datetime
os.environ["WANDB_DISABLED"] = "true"

from longbench_utils import scorer, MODEL2MAXLEN, DATASET2PROMPT, DATASET2MAXLEN
@@ -23,16 +25,38 @@ def post_process(response, model_name):
response = response.split("<eoa>")[0]
return response

# This is the customized prompt-building function for chat models
def build_chat(tokenizer, prompt, model_name):
# Copied from KIVI
if "longchat" in model_name.lower() or "vicuna" in model_name.lower():
from fastchat.model import get_conversation_template
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], prompt)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
elif "mistral-v0.2-instruct" in model_name.lower():
messages = [
{
"role": "user",
"content": prompt
}
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
return prompt

def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name):
preds = []
for json_obj in tqdm(data):
prompt = prompt_format.format(**json_obj)
# truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions)
tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]

if len(tokenized_prompt) > max_length:
half = int(max_length/2)
prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)

if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without the chat prompt template on these tasks
prompt = build_chat(tokenizer, prompt, model_name)

input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
context_length = input.input_ids.shape[-1]
@@ -53,12 +77,14 @@ def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset
num_beams=1,
do_sample=False,
temperature=1.0,
pad_token_id=tokenizer.eos_token_id,
)[0]
pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
pred = post_process(pred, model_name)
preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]})
return preds


def seed_everything(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
@@ -72,26 +98,28 @@ def main(args):
model2maxlen = MODEL2MAXLEN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)
model, tokenizer = load_model_and_tokenizer(args.model_name_or_path, use_flash_attn2=args.flash2)
configure_latent_quantizer(
model, n_bits=args.lt_bits,
group_size=args.lt_group_size,
sym=args.lt_sym,
clip_ratio=args.lt_clip_ratio,
hadamard=args.lt_hadamard
)
model_name = args.model_name_or_path.split("/")[-1]
orig_model_name = "Mistral-7B-v0.1" if "mistral" in model_name.lower() else "Llama-2-7b-chat-hf"
if "mistral" in model_name.lower():
orig_model_name = "Mistral-7B-v0.1"
elif "llama" in model_name.lower():
orig_model_name = "Llama-2-7b-chat-hf"
elif "vicuna" in model_name.lower():
orig_model_name = "vicuna-v1.5-7b-16k"

#NOTE(brian1009): This is a hack to get the model name.
# We assume the model name is the last part of the path,
# and Palu's compression information follows the model name, separated by "_".
# Hence, we split the path by "/" and keep only the part before the first "_".
# Example: Mistral-7B-Instruct-v0.2_ratio-0.7_gs-4-fisher_uniform
raw_model_name = args.model_name_or_path.split("/")[-1]
model_type = args.model_name_or_path.split("/")[-1].split('_')[0]

model.eval()
max_length = model2maxlen[orig_model_name]
logger.info(f"Running model: {model_name}")
if model_type not in model2maxlen:
raise ValueError(f"Model {model_type} not supported")

max_length = model2maxlen[model_type]
logger.info(f"Running model: {raw_model_name}")
logger.info(f"Max length: {max_length}")
datasets = args.datasets
# we design a specific prompt format and max generation length for each task; feel free to modify them to optimize model output
@@ -104,15 +132,17 @@
results = {}

for dataset in datasets:
logger.info("Evaluating dataset: {}".format(dataset))
start_time = time.time()
data = load_dataset('THUDM/LongBench', dataset, split='test')
prompt_format = dataset2prompt[dataset]
max_gen = dataset2maxlen[dataset]
preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name)
preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_type)
end_time = time.time()
elapsed_time = end_time - start_time
logger.info(f"Elapsed time for dataset {dataset}: {elapsed_time/60} minutes")


# calculate score
predictions, answers, lengths = [], [], []
for pred in preds:
predictions.append(pred["pred"])
@@ -123,11 +153,16 @@
score = scorer(dataset, predictions, answers, all_classes)
logger.info(f"dataset: {dataset}")
logger.info(f"score: {score}")
results[dataset] = {"score": score}

os.makedirs("results/Longbench", exist_ok=True)
with open(f"results/Longbench/{model_name}.txt", "w") as f:
f.write(str(results))
# Log the results for each dataset
with open(f"results/Longbench/{raw_model_name}_bits_{args.lt_bits}.json", "a") as f:
data_to_log = {
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"dataset": dataset,
"score": score,
}
json.dump(data_to_log, f)
f.write("\n")

if __name__ == '__main__':
seed_everything(42)
@@ -147,6 +182,8 @@ def main(args):

logger.remove()
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level="INFO" if not args.verbose else "DEBUG")
# Create a directory to log evaluation results.
os.makedirs("results/Longbench", exist_ok=True)

main(args)

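The logging added above appends one JSON object per dataset to a results file, one record per line. A short sketch for reading those records back; the file name below is a hypothetical instance of the results/Longbench/{raw_model_name}_bits_{lt_bits}.json pattern used in main:

```python
import json

records = []
# Hypothetical file produced by run_long_bench.py with --lt_bits 16.
with open("results/Longbench/example-model_bits_16.json") as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))

for record in records:
    print(record["timestamp"], record["dataset"], record["score"])
```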
13 changes: 8 additions & 5 deletions utils.py
@@ -4,7 +4,8 @@
import random, torch
from functools import reduce
from palu.model import HeadwiseLowRankModule
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig


def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
@@ -47,9 +48,9 @@ def get_module_by_name(module, module_name):
def dump_to_huggingface_repos(model, tokenizer, save_path, args):
tokenizer.save_pretrained(save_path)
#model.generation_config = Gene
#if "vicuna" in model.config._name_or_path.lower():
#if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
#NOTE(brian1009): Ad-hoc fixing the bug in Vicuna
#model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)
# model.config.generation_config = GenerationConfig(temperature=1.0, top_p=1.0)
model.save_pretrained(save_path)
config = model.config.to_dict()
config["head_wise_ranks"] = {}
@@ -72,7 +73,7 @@ def dump_to_huggingface_repos(model, tokenizer, save_path, args):
json.dump(config, open(save_path + "/config.json", "w"), indent=2)


def load_model_and_tokenizer(model_name_or_path):
def load_model_and_tokenizer(model_name_or_path, use_flash_attn2=False):
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
trust_remote_code=True,
@@ -82,11 +83,12 @@
torch_dtype=torch.float16,
trust_remote_code=True,
device_map="auto",
attn_implementation="flash_attention_2" if use_flash_attn2 else "sdpa",
)
model.eval()
# Fix the bug in generation configs
#TODO: Add a reference to the issue that also reports this bug
if "vicuna" in model.config._name_or_path.lower():
if "vicuna" in model.config._name_or_path.lower() or "longchat" in model.config._name_or_path.lower():
model.generation_config.do_sample = True

return model, tokenizer
@@ -101,4 +103,5 @@ def add_common_args(parser: argparse.ArgumentParser):
parser.add_argument('--lt_sym', action='store_true', help='Symmetric quantization for low_rank latents')
parser.add_argument('--lt_clip_ratio', type=float, help='Clip ratio for low_rank latents', default=1.0)
parser.add_argument('--lt_hadamard', action='store_true', help='Apply Hadamard transform to low_rank latents')
parser.add_argument('--flash2', action='store_true', help='Whether to use FlashAttention-2')
return parser
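A short usage sketch for the updated loader with the new flash2 option; the checkpoint path below is a placeholder, and FlashAttention-2 must be installed when use_flash_attn2=True:

```python
from utils import load_model_and_tokenizer

# "path/to/palu-compressed-model" is a placeholder for a Palu checkpoint directory.
model, tokenizer = load_model_and_tokenizer(
    "path/to/palu-compressed-model",
    use_flash_attn2=True,  # set False to fall back to the "sdpa" implementation
)
```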