import torch
import torch.nn as nn
import bitsandbytes as bnb  # quantization backend used by load_in_8bit
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
# Set up the base model and tokenizer.
# Note: despite the file name, this loads the 1.7B-parameter BLOOM checkpoint.
model_id = "bigscience/bloom-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
print(model.get_memory_footprint())  # memory footprint of the quantized model, in bytes
'''
Changing the compute dtype
The compute dtype controls the dtype used during computation. For example, hidden
states can be kept in float32 while computation runs in bfloat16 for speedups.
By default, the compute dtype is float32.
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
'''
'''
Using the NF4 (Normal Float 4) data type
You can also use the NF4 data type, a 4-bit data type adapted for weights that were
initialized from a normal distribution. To use it, run:
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
'''
'''
Using nested quantization for more memory-efficient inference
We also advise users to use the nested quantization technique. It saves additional
memory at no extra performance cost: from our empirical observations, this enables
fine-tuning a llama-13b model on an NVIDIA T4 16GB with a sequence length of 1024,
a batch size of 1, and gradient accumulation steps of 4.
double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=double_quant_config)
'''
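'''
The options above compose. A minimal sketch (an illustrative combination, not part
of the original script) that enables NF4, nested quantization, and a bfloat16
compute dtype in a single config:
combined_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=combined_config)
'''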
# Freeze the original weights and prepare the model for training.
for param in model.parameters():
    param.requires_grad = False  # freeze the base model
    if param.ndim == 1:
        # cast small parameters (e.g. layernorm weights) to fp32 for stability
        param.data = param.data.to(torch.float32)
model.gradient_checkpointing_enable()  # trade compute for memory during backprop
model.enable_input_require_grads()  # needed so gradients flow into the adapters
class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)  # cast logits to fp32 for a stable loss
# Set up the LoRA adapters.
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

config = LoraConfig(
    r=16,  # rank of the low-rank update matrices
    lora_alpha=32,  # scaling factor applied to the LoRA updates
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
print_trainable_parameters(model)
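'''
Note: PEFT models also expose an equivalent built-in helper, so the function above
could be replaced with:
model.print_trainable_parameters()
'''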
# Load the dataset and build the training text: each quote followed by its tags.
data = load_dataset("Abirate/english_quotes")
def merge_columns(example):
    example['prediction'] = example['quote'] + " ->: " + str(example["tags"])
    return example
data['train'] = data['train'].map(merge_columns)
print(data['train']["prediction"][:5])
print(data['train'][0])
data = data.map(lambda samples: tokenizer(samples['prediction']), batched=True)
print(data)
# Training
trainer = transformers.Trainer(
    model=model,
    train_dataset=data['train'],
    args=transformers.TrainingArguments(
        per_device_train_batch_size=4,  # per_gpu_train_batch_size is deprecated
        gradient_accumulation_steps=4,
        warmup_steps=100,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=1,
        output_dir='outputs',
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False  # silence warnings; re-enable for inference
trainer.train()
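'''
Optional: save the LoRA adapter weights locally before pushing. A minimal sketch
(the output path is an illustrative assumption):
model.save_pretrained("outputs/lora-adapter")
'''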
# Push the trained adapter to the Hugging Face Hub.
# "HuggingFace-app-key" is a placeholder; substitute a real Hub access token.
model.push_to_hub("meetrais/bloom-7b1-lora-tagger",
                  token="HuggingFace-app-key",
                  commit_message="basic training",
                  private=True)
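'''
Loading the adapter back for inference. A minimal sketch under these assumptions:
the repo id matches the one pushed above, and the prompt is an illustrative quote
formatted with the same " ->: " separator used in training.
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
inference_model = PeftModel.from_pretrained(base, "meetrais/bloom-7b1-lora-tagger")
inference_model.config.use_cache = True  # re-enable the KV cache for faster generation

batch = tokenizer(
    "Two things are infinite: the universe and human stupidity. ->: ",
    return_tensors="pt",
).to(inference_model.device)
with torch.no_grad():
    output = inference_model.generate(**batch, max_new_tokens=50)
print(tokenizer.decode(output[0], skip_special_tokens=True))
'''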