New TF loading weights #8490
Changes from 3 commits
```diff
@@ -236,15 +236,15 @@ def compute_loss(self, labels, logits):
         return loss_fn(next_sentence_label, next_sentence_reduced_logits)
 
 
-def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
+def load_tf_weights(model, resolved_archive_file):
     """
-    Detect missing and unexpected layers.
+    Detect missing and unexpected layers
```

**Review comment:** The docstring should be adapted since this function loads the TF weights and detects missing and unexpected layers.
```diff
 
     Args:
         model (:obj:`tf.keras.models.Model`):
             The model to load the weights into.
         resolved_archive_file (:obj:`str`):
-            The location of the H5 file.
+            The location of the H5 file
```

**Review comment:** No need to remove this dot.
```diff
 
     Returns:
         Two lists, one for the missing layers, and another one for the unexpected layers.
```
```diff
@@ -253,60 +253,49 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
     unexpected_layers = []
 
     with h5py.File(resolved_archive_file, "r") as f:
-        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
-        model_layer_names = set(layer.name for layer in model.layers)
-        missing_layers = list(model_layer_names - saved_layer_names)
-        unexpected_layers = list(saved_layer_names - model_layer_names)
+        saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        model_layers_name_value = {}
```

**Review comment:** This line and the `for` loop after it should be replaced by a dict comprehension (see the sketch below). This is easier code to read IMO and will be faster (a dict comprehension is faster than a manual loop).
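The exact snippet the reviewer suggested is not preserved in this thread; given the surrounding code, it was presumably a one-liner along these lines:

```python
# Hypothetical reconstruction of the reviewer's suggestion: build the
# layer-name -> layer mapping with a dict comprehension instead of a manual loop.
model_layers_name_value = {layer.name: layer for layer in model.layers}
```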
```diff
 
         for layer in model.layers:
-            if layer.name in saved_layer_names:
-                g = f[layer.name]
-                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
-                saved_weight_names_set = set(
-                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
-                )
```

**Review comment (on lines -265 to -267):** This line and the line defining the symbolic weight names used `"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names`; they will now be filled by `"/".join(weight_name.split("/")[1:]) for weight_name in saved_weight_names`. Why the change from `[2:]` to `[1:]`?

**Reply:** I took a deeper look and understand the change here. I guess it's because the second index is the prefix, and since in TF the main layer is named after the prefix, it will remain the same across base models and models with heads.

**Reply:** I'm really sorry @LysandreJik, I completely forgot to answer :( And yes, this is exactly for that :)
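For intuition: TF weight names are `/`-separated paths whose first component is the top-level model name and whose second is the shared main-layer prefix. A hedged illustration (the checkpoint names below are made up):

```python
# Illustrative only: hypothetical weight names as stored in a checkpoint.
saved_weight_names = [
    "tf_bert_model/bert/embeddings/word_embeddings/weight:0",
    "tf_bert_for_masked_lm/bert/encoder/layer_0/attention/self/query/kernel:0",
]

# Dropping only the first component ([1:]) keeps the shared "bert/..." suffix,
# so names still match between a base model and a model with a head.
for weight_name in saved_weight_names:
    print("/".join(weight_name.split("/")[1:]))
# bert/embeddings/word_embeddings/weight:0
# bert/encoder/layer_0/attention/self/query/kernel:0
```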
```diff
-                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
-                symbolic_weights_names = set(
-                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
-                )
-                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
-                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+            name = layer.name
+            model_layers_name_value[name] = layer
 
-    return missing_layers, unexpected_layers
+        model_layers_name = set(model_layers_name_value.keys())
+        renamed_saved_h5_model_layers_names = set()
 
+        for layer_name in saved_h5_model_layers_name:
+            name = layer_name
 
-def load_tf_weights(model, resolved_archive_file):
-    """
-    Load the TF weights from a H5 file.
+            renamed_saved_h5_model_layers_names.add(name)
```

**Review comment:** I don't get this code. It loops through the set `saved_h5_model_layers_name`, but why do we need a new variable, and why is it "renamed"?
```diff
 
-    Args:
-        model (:obj:`tf.keras.models.Model`):
-            The model to load the weights into.
-        resolved_archive_file (:obj:`str`):
-            The location of the H5 file.
-    """
-    with h5py.File(resolved_archive_file, "r") as f:
-        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        missing_layers = list(model_layers_name - renamed_saved_h5_model_layers_names)
+        unexpected_layers = list(renamed_saved_h5_model_layers_names - model_layers_name)
+        saved_weight_names_set = set()
+        symbolic_weights_names = set()
         weight_value_tuples = []
 
-        for layer in model.layers:
-            if layer.name in saved_layer_names:
-                g = f[layer.name]
+        for layer_name in saved_h5_model_layers_name:
```

**Review comment:** Any reason to change the loop? AFAICT we only do things for objects that are in the intersection of the saved layer names and the model layer names (see the sketch after this thread).

**Reply:** You are right, this variable was not really necessary and the previous loop could be kept as it was. I changed it.
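A small runnable sketch of the reviewer's point (the layer names below are invented): iterating the set intersection directly makes the membership guard unnecessary.

```python
# Made-up layer names for illustration.
saved_h5_model_layers_name = {"bert", "mlm___cls"}  # names found in the H5 file
model_layers_name = {"bert", "nsp___cls"}           # names on the instantiated model

# Only layers present on both sides need any work.
for layer_name in saved_h5_model_layers_name & model_layers_name:
    print("would load:", layer_name)  # prints: would load: bert
```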
```diff
+            if layer_name in model_layers_name:
+                g = f[layer_name]
```

**Review comment:** What is `g`?

**Reply:** Better comment?

```diff
 
                 saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                layer = model_layers_name_value[layer_name]
                 symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
+                saved_weight_names_values = {}
```

**Review comment:** It's a nit, but I think it's cleaner to give this dict a shorter name.

**Reply:** Updated!
```diff
 
+                for weight_name in saved_weight_names:
+                    name = "/".join(weight_name.split("/")[1:])
+                    saved_weight_names_values[name] = np.asarray(g[weight_name])
+
+                    saved_weight_names_set.add(name)
 
                 for symbolic_weight in symbolic_weights:
-                    splited_layers = symbolic_weight.name.split("/")[1:]
-                    symbolic_weight_name = "/".join(splited_layers)
+                    symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+                    saved_weight_value = None
 
+                    if symbolic_weight_name in saved_weight_names_values:
+                        saved_weight_value = saved_weight_names_values[symbolic_weight_name]
 
+                    if saved_weight_value is not None:
```

**Review comment:** Why add this test and the line before it (`saved_weight_value = None`)?

**Reply:** Agree here, I think we can just put all the code below directly under the membership check.

**Reply:** Alternatively we could do the lookup and the default in one step.

**Reply:** Like this?
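The alternative the reviewer had in mind is not preserved in this scrape; it was presumably the standard dict idiom, sketched below:

```python
# Hypothetical reconstruction: dict.get returns None when the key is absent,
# collapsing the membership test and the sentinel assignment into one line.
saved_weight_value = saved_weight_names_values.get(symbolic_weight_name, None)
```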
```diff
 
                         if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                             try:
                                 array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
```
```diff
@@ -318,8 +307,15 @@ def load_tf_weights(model, resolved_archive_file):
 
                     weight_value_tuples.append((symbolic_weight, array))
 
+                    symbolic_weights_names.add(symbolic_weight_name)
```

**Review comment:** This will add the name of every symbolic weight, even those not found in the checkpoint.

**Reply:** Yes, this is expected. I need to have the name of all the layers from the checkpoint on one side, and from the instantiated model on the other side, in order to properly compute the missing and unexpected layers.

```diff
 
     K.batch_set_value(weight_value_tuples)
 
+    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
 
+    return missing_layers, unexpected_layers
```
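A hedged usage sketch of the merged function (the checkpoint path below is a placeholder): after this change, loading and mismatch detection happen in a single pass over the H5 file.

```python
# Sketch, assuming load_tf_weights as defined above and an instantiated Keras model.
missing_layers, unexpected_layers = load_tf_weights(model, "path/to/tf_model.h5")

if missing_layers:
    print("Weights newly initialized (absent from the checkpoint):", missing_layers)
if unexpected_layers:
    print("Checkpoint weights not used by the model:", unexpected_layers)
```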
```diff
 
 
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
```

```diff
@@ -728,7 +724,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             # 'by_name' allow us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                load_tf_weights(model, resolved_archive_file)
+                missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. "
```
```diff
@@ -737,8 +733,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
 
         model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
-        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)
 
         if cls.authorized_missing_keys is not None:
             for pat in cls.authorized_missing_keys:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
```
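For clarity, the `authorized_missing_keys` filter above drops any missing key matching one of the class's regex patterns before warnings are emitted. A small self-contained illustration (the key names are invented):

```python
import re

# Hypothetical patterns and keys, mirroring the filtering loop above.
authorized_missing_keys = [r"position_ids$"]
missing_keys = ["bert/embeddings/position_ids", "classifier/kernel"]

for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['classifier/kernel']
```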
```diff
@@ -1034,18 +1028,18 @@ def call(self, inputs, cls_index=None, training=False):
         return output
 
 
-def shape_list(x: tf.Tensor) -> List[int]:
+def shape_list(tensor: tf.Tensor) -> List[int]:
     """
     Deal with dynamic shape in tensorflow cleanly.
 
     Args:
-        x (:obj:`tf.Tensor`): The tensor we want the shape of.
+        tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
 
     Returns:
         :obj:`List[int]`: The shape of the tensor as a list.
     """
-    static = x.shape.as_list()
-    dynamic = tf.shape(x)
+    static = tensor.shape.as_list()
+    dynamic = tf.shape(tensor)
     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
```
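A usage sketch for `shape_list` (assuming the definition above is in scope): static dimensions come back as plain Python ints, while unknown dimensions come back as dynamic scalar tensors, so the result is safe to use for reshapes inside a `tf.function`.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, 128], tf.float32)])
def flatten(x):
    batch, width = shape_list(x)  # batch: dynamic tensor, width: the int 128
    return tf.reshape(x, [batch * width])

print(flatten(tf.zeros([2, 128])).shape)  # (256,)
```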
**Review comment:** No need for the `_, _ =`, I think.