# lora_utils.py
import os
from os.path import exists, isdir, join

# Fallback pad token for tokenizers that do not define one.
DEFAULT_PAD_TOKEN = "[PAD]"

def print_trainable_parameters(args, model):
    """Print the number of trainable parameters in the model."""
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # Halve the reported trainable count when training in 4-bit mode.
    if args.bits == 4:
        trainable_params /= 2
    print(
        f"trainable params: {trainable_params} || "
        f"all params: {all_param} || "
        f"trainable: {100 * trainable_params / all_param}%"
    )

def get_last_checkpoint(checkpoint_dir):
    """Return (checkpoint_path, is_completed) for a previous run, if any."""
    if isdir(checkpoint_dir):
        is_completed = exists(join(checkpoint_dir, "completed"))
        if is_completed:
            return None, True  # already finished
        # Find the highest-numbered checkpoint-<step> subdirectory.
        max_step = 0
        for filename in os.listdir(checkpoint_dir):
            if isdir(join(checkpoint_dir, filename)) and filename.startswith(
                "checkpoint"
            ):
                max_step = max(max_step, int(filename.replace("checkpoint-", "")))
        if max_step == 0:
            return None, is_completed  # training started, but no checkpoint yet
        checkpoint_dir = join(checkpoint_dir, f"checkpoint-{max_step}")
        print(f"Found a previous checkpoint at: {checkpoint_dir}")
        return checkpoint_dir, is_completed  # checkpoint found!
    return None, False  # first training
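

# ---------------------------------------------------------------------------
# Usage sketch. Everything below is illustrative: `SimpleNamespace` stands in
# for the training script's parsed arguments, "./output" is a hypothetical
# output directory, and the toy model merely exercises the counting logic.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch.nn as nn

    # Toy model with one frozen layer, so trainable < total in the printout.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
    for p in model[0].parameters():
        p.requires_grad = False
    print_trainable_parameters(SimpleNamespace(bits=16), model)

    # Resume logic driven by get_last_checkpoint(); a Trainer would typically
    # continue via trainer.train(resume_from_checkpoint=checkpoint_dir).
    checkpoint_dir, completed = get_last_checkpoint("./output")
    if completed:
        print("Training already finished; nothing to resume.")
    elif checkpoint_dir is not None:
        print(f"Resuming from {checkpoint_dir}")
    else:
        print("No previous run found; starting from scratch.")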