@@ -1,15 +1,22 @@
+ import functools
import logging
import os
+ import random
import sys
+ from itertools import chain
+ from typing import Dict

import datasets
import determined as det
import evaluate
+ import numpy as np
import torch
import transformers
+ import wandb
+
from determined.transformers import DetCallback
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model
- from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments, get_linear_schedule_with_warmup
from trl import DataCollatorForCompletionOnlyLM

from chat_format import get_chat_format, get_response_template_ids, set_special_tokens
@@ -18,17 +25,31 @@
logger = logging.getLogger(__name__)


- def get_tokenizer(model_name, model_commit_hash):
+ def get_tokenizer(model_name, model_commit_hash, hparams):
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        padding_side="right",
        truncation_side="right",
        revision=model_commit_hash,
+         token=hparams["hf_token"],
    )
    set_special_tokens(tokenizer, model_name)
    return tokenizer


+ def standardize_lora_init(lora_layer, alpha: int):
+     self_attn = lora_layer.self_attn
+     q_proj = self_attn.q_proj.lora_A.default
+     v_proj = self_attn.v_proj.lora_A.default
+     with torch.no_grad():
+         sd_q = q_proj.state_dict()
+         sd_q['weight'] = sd_q['weight'] / alpha
+         q_proj.load_state_dict(sd_q)
+         sd_v = v_proj.state_dict()
+         sd_v['weight'] = sd_v['weight'] / alpha
+         v_proj.load_state_dict(sd_v)
+
+
def get_model_and_tokenizer(model_name, use_lora, hparams, inference=False, device_map="auto", model_commit_hash=None):
    if inference:
        if use_lora:
@@ -47,22 +68,55 @@ def get_model_and_tokenizer(model_name, use_lora, hparams, inference=False, devi
            model_name,
            torch_dtype=torch.bfloat16,
            revision=model_commit_hash,
+             token=hparams["hf_token"],
        )
+         model.enable_input_require_grads()

        if use_lora:
            r = hparams["r"]
-             lora_alpha = r * hparams["lora_alpha_in_r"]
+             lora_alpha = hparams["lora_alpha"]
            peft_config = LoraConfig(
                task_type="CAUSAL_LM",
                inference_mode=False,
                r=r,
                lora_alpha=lora_alpha,
                lora_dropout=hparams["lora_dropout"],
+                 use_rslora=hparams["use_rslora"]
            )

            model = get_peft_model(model, peft_config)

-     tokenizer = get_tokenizer(model_name, model_commit_hash=model_commit_hash)
+             lora_a = model.base_model.model.model.layers[0].self_attn.q_proj.lora_A.default
+             print("LoRA a at initialization, before rescaling, layer 0, q_proj:")
+             print(lora_a.state_dict())
+             lora_a = model.base_model.model.model.layers[31].self_attn.q_proj.lora_A.default
+             print("LoRA a at initialization, before rescaling, layer 31, q_proj:")
+             print(lora_a.state_dict())
+             lora_a = model.base_model.model.model.layers[0].self_attn.v_proj.lora_A.default
+             print("LoRA a at initialization, before rescaling, layer 0, v_proj:")
+             print(lora_a.state_dict())
+             lora_a = model.base_model.model.model.layers[31].self_attn.v_proj.lora_A.default
+             print("LoRA a at initialization, before rescaling, layer 31, v_proj:")
+             print(lora_a.state_dict())
+
+             if hparams["custom_scale_init"]:
+                 for l in model.base_model.model.model.layers:
+                     standardize_lora_init(l, lora_alpha)
+
+                 lora_a = model.base_model.model.model.layers[0].self_attn.q_proj.lora_A.default
+                 print("LoRA a at initialization, after rescaling, layer 0, q_proj:")
+                 print(lora_a.state_dict())
+                 lora_a = model.base_model.model.model.layers[31].self_attn.q_proj.lora_A.default
+                 print("LoRA a at initialization, after rescaling, layer 31, q_proj:")
+                 print(lora_a.state_dict())
+                 lora_a = model.base_model.model.model.layers[0].self_attn.v_proj.lora_A.default
+                 print("LoRA a at initialization, after rescaling, layer 0, v_proj:")
+                 print(lora_a.state_dict())
+                 lora_a = model.base_model.model.model.layers[31].self_attn.v_proj.lora_A.default
+                 print("LoRA a at initialization, after rescaling, layer 31, v_proj:")
+                 print(lora_a.state_dict())
+
+     tokenizer = get_tokenizer(model_name, model_commit_hash=model_commit_hash, hparams=hparams)
    return model, tokenizer

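A note on the rescaling logic above (added context, not part of this commit): in PEFT, the LoRA delta applied to a frozen weight is scaling * (B @ A), with scaling = lora_alpha / r for standard LoRA and lora_alpha / sqrt(r) when use_rslora=True. Dividing the lora_A weights by alpha, as standardize_lora_init does, leaves the scaled factor scaling * A at A / r (or A / sqrt(r)) regardless of which alpha is being swept, which is what the custom_scale_init flag appears to standardize. A minimal sketch under those assumptions:

import math
import torch

def scaled_lora_A(A: torch.Tensor, lora_alpha: float, r: int, use_rslora: bool = False) -> torch.Tensor:
    # PEFT applies delta_W = scaling * (B @ A); B starts at zero, so the magnitude
    # of scaling * A is what sets the scale of the first adapter updates.
    scaling = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    return scaling * A

A = torch.randn(8, 4096)
print(scaled_lora_A(A, lora_alpha=64, r=8).std())        # grows with alpha
print(scaled_lora_A(A / 64, lora_alpha=64, r=8).std())   # after dividing A by alpha, alpha cancels out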
@@ -73,6 +127,23 @@ def fn(formatted):
    return fn


+ def group_texts(examples, block_size) -> Dict:
+     # Concatenate all texts.
+     concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
+     total_length = len(concatenated_examples[list(examples.keys())[0]])
+     # We drop the small remainder; we could pad instead if the model supported it.
+     # Customize this part to your needs.
+     if total_length >= block_size:
+         total_length = (total_length // block_size) * block_size
+     # Split into chunks of block_size.
+     result = {
+         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+         for k, t in concatenated_examples.items()
+     }
+     result["labels"] = result["input_ids"].copy()
+     return result
+
+
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, logits may contain extra tensors,
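For reference, a toy run of the group_texts helper added above (illustration only, not part of the diff): with block_size=4, the tokenized rows are concatenated, the remainder that does not fill a full block is dropped, and labels become a copy of input_ids.

examples = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1], [1, 1]],
}
grouped = group_texts(examples, block_size=4)
# grouped["input_ids"]      -> [[1, 2, 3, 4], [5, 6, 7, 8]]  (the trailing token 9 is dropped)
# grouped["attention_mask"] -> [[1, 1, 1, 1], [1, 1, 1, 1]]
# grouped["labels"]         -> a copy of grouped["input_ids"]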
@@ -105,10 +176,18 @@ def tokenize(element):
        }

    dataset = load_or_create_dataset(hparams["dataset_subset"])
+     block_size = hparams["block_size"]
    column_names = list(dataset["train"].features)
    for k in dataset.keys():
        dataset[k] = dataset[k].map(tokenize, remove_columns=column_names)
-
+     if hparams["group_text"]:
+         with training_args.main_process_first(desc="grouping texts together", local=False):
+             dataset = dataset.map(
+                 functools.partial(group_texts, block_size=block_size),
+                 batched=True,
+                 desc=f"Grouping texts in chunks of {block_size}",
+             )
+
    response_template_ids = get_response_template_ids(tokenizer, model_name)
    collator = DataCollatorForCompletionOnlyLM(
        response_template_ids, tokenizer=tokenizer
@@ -151,6 +230,18 @@ def compute_metrics(eval_preds):

    trainer.train()

+ def set_seed(seed: int = 42) -> None:
+     np.random.seed(seed)
+     random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     # When running on the CuDNN backend, two further options must be set
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     # Set a fixed value for the hash seed
+     os.environ["PYTHONHASHSEED"] = str(seed)
+     print(f"Random seed set as {seed}")
+

if __name__ == "__main__":
    # Setup logging
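The hand-rolled set_seed above seeds the Python, NumPy, and torch RNGs and flips the cuDNN flags by hand. A shorter alternative (not what this commit uses) is the helpers that ship with transformers, where enable_full_determinism also covers the determinism switches:

from transformers import enable_full_determinism, set_seed

set_seed(42)                 # seeds random, numpy and torch (CPU and CUDA) in one call
enable_full_determinism(42)  # set_seed plus deterministic cuDNN/backend settings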
@@ -169,12 +260,19 @@ def compute_metrics(eval_preds):
    hparams = info.trial.hparams

    if "hf_token" in hparams:
+         print("SWY logged flow triggered")
+         hf_token = hparams["hf_token"]
+         print(f"SWY token is {hf_token}")
        import huggingface_hub
-
        huggingface_hub.login(token=hparams["hf_token"])

    if hparams["training_args"]["deepspeed"]:
-         hparams["training_args"]["deepspeed"] = "ds_configs/ds_config_stage_3.json"
+         if not hparams["use_adam"]:
+             hparams["training_args"]["deepspeed"] = "ds_configs/ds_config_stage_3.json"
+             print("swy not using adam")
+         else:
+             hparams["training_args"]["deepspeed"] = "ds_configs/ds_config_stage_3_adam.json"
+             print("swy using adam")

    training_args = TrainingArguments(**hparams["training_args"])
    if training_args.deepspeed:
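The two JSON files selected above are not included in this diff, so their contents are assumed here. A rough sketch of what ds_config_stage_3_adam.json could contain, written as the Python dict that would be serialized to JSON (standard ZeRO stage-3 keys plus an explicit DeepSpeed Adam optimizer block; the plain ds_config_stage_3.json would presumably omit the optimizer section and let the Trainer create its default AdamW):

ds_config_stage_3_adam = {
    "bf16": {"enabled": "auto"},
    "optimizer": {
        "type": "Adam",
        "params": {"lr": "auto", "betas": "auto", "eps": "auto", "weight_decay": "auto"},
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": True,
        "stage3_gather_16bit_weights_on_model_save": True,
    },
    "gradient_accumulation_steps": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
}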
@@ -186,8 +284,50 @@ def compute_metrics(eval_preds):
        distributed = det.core.DistributedContext.from_deepspeed()
    else:
        distributed = det.core.DistributedContext.from_torch_distributed()
-
+
+     random_seed = 42
+
    with det.core.init(distributed=distributed) as core_context:
+         if core_context.distributed.rank == 0:
+             wandb.login(key=hparams["wandb_key"])
+             import uuid
+             # Generate a UUID
+             my_uuid = uuid.uuid4()
+             # Convert UUID to string
+             uuid_str = str(my_uuid)[:5]
+             r = hparams["r"]
+             lora_alpha = hparams["lora_alpha"]
+             lora_dropout = hparams["lora_dropout"]
+             dataset_subset = hparams["dataset_subset"]
+             lr = str(hparams["training_args"]["learning_rate"])
+             use_rslora = False
+             if "use_rslora" in hparams:
+                 use_rslora = hparams["use_rslora"]
+             optimizer = "adamW"
+             if "use_adam" in hparams and hparams["use_adam"]:
+                 optimizer = "adam"
+             run_name = f"test_lora_blog_{dataset_subset}_r_{r}_alpha_{lora_alpha}_dropout_{lora_dropout}_lr_{lr}_seed_{random_seed}_opt_{optimizer}"
+             if use_rslora:
+                 run_name += "_rslora"
+             run_name += f"_{uuid_str}"
+             run = wandb.init(
+                 project="lora-blog-v3",
+                 name=run_name,
+                 config={
+                     "r": hparams["r"],
+                     "lora_alpha": hparams["lora_alpha"],
+                     "dropout": hparams["lora_dropout"],
+                     "dataset_subset": hparams["dataset_subset"],
+                     "model": hparams["model"],
+                     "lr": lr,
+                     "seed": random_seed,
+                     "optimizer": optimizer,
+                     "use_rslora": use_rslora
+                 }
+             )
+
+         set_seed(random_seed)
+
        det_callback = DetCallback(
            core_context,
            training_args,