# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
import os
import math
import time

import torch
from transformers import (
    AutoConfig,
    AutoModel,
)
from huggingface_hub import snapshot_download
from transformers.deepspeed import HfDeepSpeedConfig

from .reward_model import RewardModel
from ..utils import load_state_dict_into_model, print_rank_0


def configure_dropout(model_config, dropout):
    if dropout is not None:
        for key in ('dropout', 'attention_dropout', 'hidden_dropout',
                    'activation_dropout'):
            if hasattr(model_config, key):
                print(f"Setting model_config.{key} to {dropout}")
                setattr(model_config, key, dropout)
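
# Usage sketch (hypothetical; "facebook/opt-1.3b" is just an example
# checkpoint). Disabling dropout this way is common before RLHF fine-tuning:
#
#   config = AutoConfig.from_pretrained("facebook/opt-1.3b")
#   configure_dropout(config, 0.0)  # zeroes every *_dropout field the config defines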


def causal_lm_model_to_fp32_loss(model):
    """ Convert CausalLM model to calculate loss in fp32 """

    def causal_lm_forward(
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **deprecated_arguments,
    ):
        # llama's forward() does not accept head_mask, so only pass it
        # through for other model families.
        kwargs = dict() if model.config.model_type == "llama" else dict(
            head_mask=head_mask)
        output = model.__original_forward__(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        return_dict = isinstance(output, dict)
        lm_logits = output.logits if return_dict else output[0]
        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n; cast logits to fp32 before
            # the cross-entropy so the loss is computed in full precision.
            shift_logits = lm_logits[..., :-1, :].float().contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size),
                shift_labels.view(batch_size * seq_length))

        if not return_dict:
            # re-pack output with fp32 loss
            return ((loss, ) + output) if loss is not None else output

        output.loss = loss
        return output

    model.__original_forward__ = model.forward
    model.forward = causal_lm_forward
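
# Usage sketch (hypothetical model name): wrap an existing CausalLM so that
# forward() returns an fp32 loss even when the model itself runs in fp16/bf16.
#
#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b",
#                                                torch_dtype=torch.float16)
#   causal_lm_model_to_fp32_loss(model)
#   out = model(input_ids=batch["input_ids"], labels=batch["input_ids"])
#   out.loss.dtype  # torch.float32, even though out.logits are fp16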


def create_hf_model(model_class,
                    model_name_or_path,
                    tokenizer,
                    ds_config=None,
                    rlhf_training=False,
                    dropout=None):
    model_config = AutoConfig.from_pretrained(model_name_or_path)
    configure_dropout(model_config, dropout)

    # Note: dschf is defined in function scope to avoid global effects
    # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
    if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
        dschf = HfDeepSpeedConfig(ds_config)
    else:
        dschf = None

    if rlhf_training:
        # the weight loading is handled by create_critic_model
        model = model_class.from_config(model_config)
    else:
        model = model_class.from_pretrained(
            model_name_or_path,
            from_tf=bool(".ckpt" in model_name_or_path),
            config=model_config)

    model.config.end_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = model.config.eos_token_id
    # make the vocab size a multiple of 8
    model.resize_token_embeddings(int(8 * math.ceil(len(tokenizer) / 8.0)))

    return model
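
# Usage sketch (hypothetical names; ds_config would come from your own
# DeepSpeed configuration):
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
#   actor = create_hf_model(AutoModelForCausalLM, "facebook/opt-1.3b",
#                           tokenizer, ds_config=None, dropout=0.0)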


def create_critic_model(model_name_or_path,
                        tokenizer,
                        ds_config,
                        num_padding_at_beginning=0,
                        rlhf_training=False,
                        dropout=None,
                        zero_stage=0,
                        compute_fp32_loss=False):
    # The OPT model family always puts a padding token at the beginning of the
    # sequence. We have not seen this in other models, but we are not sure
    # whether it is a general rule.
    start = time.time()
    critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer,
                                   ds_config, rlhf_training, dropout)
    end = time.time()
    print_rank_0(f">Creating model from_config took {end - start} seconds",
                 None)

    critic_model = RewardModel(
        critic_model,
        tokenizer,
        num_padding_at_beginning=num_padding_at_beginning,
        compute_fp32_loss=compute_fp32_loss)

    if rlhf_training:
        # load critic model from checkpoint
        if not os.path.isdir(model_name_or_path):
            model_name_or_path = snapshot_download(model_name_or_path)
        model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin')
        assert os.path.exists(
            model_ckpt_path
        ), f"Cannot find model checkpoint at {model_ckpt_path}"

        start = time.time()
        model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu')
        end = time.time()
        print_rank_0(f">Loading checkpoint state dict took {end - start} seconds",
                     None)

        # load critic model from checkpoint with zero-stage 3 compatibility
        # this functionality may be moved to DS checkpoint load API in future
        start = time.time()
        load_state_dict_into_model(critic_model,
                                   model_ckpt_state_dict,
                                   "",
                                   zero_stage=zero_stage)
        end = time.time()
        print_rank_0(f">Loading state dict into model took {end - start} seconds",
                     None)

    return critic_model
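
# Usage sketch (hypothetical path; "./step2_reward_ckpt" stands in for a
# directory produced by reward-model training, which must contain
# pytorch_model.bin). num_padding_at_beginning=1 follows the OPT convention
# noted above:
#
#   reward_model = create_critic_model("./step2_reward_ckpt", tokenizer,
#                                      ds_config, num_padding_at_beginning=1,
#                                      rlhf_training=True, zero_stage=3)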