New TF loading weights #8490

Merged: 11 commits, Nov 18, 2020
Changes from 3 commits
2 changes: 1 addition & 1 deletion src/transformers/modeling_tf_pytorch_utils.py
@@ -280,7 +280,7 @@ def load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path, tf_inputs
     if tf_inputs is not None:
         tf_model(tf_inputs, training=False)  # Make sure model is built

-    load_tf_weights(tf_model, tf_checkpoint_path)
+    _, _ = load_tf_weights(tf_model, tf_checkpoint_path)
Member: No need for the `_, _ =` I think


     return load_tf2_model_in_pytorch_model(pt_model, tf_model, allow_missing_keys=allow_missing_keys)

80 changes: 37 additions & 43 deletions src/transformers/modeling_tf_utils.py
@@ -236,15 +236,15 @@ def compute_loss(self, labels, logits):
         return loss_fn(next_sentence_label, next_sentence_reduced_logits)


-def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
+def load_tf_weights(model, resolved_archive_file):
     """
-    Detect missing and unexpected layers.
+    Detect missing and unexpected layers
Collaborator: The docstring should be adapted since this function loads the TF weights and detects missing and unexpected layers.
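A possible adaptation along those lines (a sketch of a reworded docstring, not the text that was eventually merged):

```python
def load_tf_weights(model, resolved_archive_file):
    """
    Load the TF weights from the H5 file into ``model`` and detect the missing
    and unexpected layers.

    Args:
        model (:obj:`tf.keras.models.Model`):
            The model to load the weights into.
        resolved_archive_file (:obj:`str`):
            The location of the H5 file.

    Returns:
        Two lists, one for the missing layers, and another one for the unexpected layers.
    """
```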


     Args:
         model (:obj:`tf.keras.models.Model`):
             The model to load the weights into.
         resolved_archive_file (:obj:`str`):
-            The location of the H5 file.
+            The location of the H5 file
Collaborator: No need to remove this dot.


     Returns:
         Two lists, one for the missing layers, and another one for the unexpected layers.
@@ -253,60 +253,49 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
     unexpected_layers = []

     with h5py.File(resolved_archive_file, "r") as f:
-        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
-        model_layer_names = set(layer.name for layer in model.layers)
-        missing_layers = list(model_layer_names - saved_layer_names)
-        unexpected_layers = list(saved_layer_names - model_layer_names)
+        saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        model_layers_name_value = {}
Collaborator: This line and the for loop after it should be replaced by `model_layers_name_value = {layer.name: layer for layer in model.layers}`. This is easier code to read IMO and will be faster (a dict comprehension is faster than a manual loop).
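For reference, the two forms build the same mapping (a minimal sketch, assuming `model` is a built `tf.keras.Model`):

```python
# Manual loop, as written in this revision of the PR:
model_layers_name_value = {}
for layer in model.layers:
    model_layers_name_value[layer.name] = layer

# Suggested dict comprehension: same result, shorter and typically faster.
model_layers_name_value = {layer.name: layer for layer in model.layers}
```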


         for layer in model.layers:
-            if layer.name in saved_layer_names:
-                g = f[layer.name]
-                saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
-                saved_weight_names_set = set(
-                    "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
-                )
Comment on lines -265 to -267
Member: This line and the line defining `symbolic_weights_names` were removed. Instead of being filled by `"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names`, they will now be filled by `"/".join(weight_name.split("/")[1:]) for weight_name in saved_weight_names`. Why the change from `2:` to `1:`?

Member: I took a deeper look and understand the change here. I guess it's because the second index is the prefix, and since in TF the main layer is named after the prefix, it will remain the same across base models and models with heads.

Contributor Author: I'm really sorry @LysandreJik, I completely forgot to answer :( And yes, this is exactly for that :)
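To make the `2:` versus `1:` discussion concrete, a sketch with hypothetical weight names (the exact names depend on the architecture; these are for illustration only):

```python
# Index 0 is the top-level model scope, which differs between a base model and
# a model with a head; index 1 is the prefix ("bert" here), which the TF main
# layer is named after and which is therefore the same in both cases.
base_model_name = "tf_bert_model/bert/embeddings/word_embeddings/weight:0"
with_head_name = "tf_bert_for_sequence_classification/bert/embeddings/word_embeddings/weight:0"

for name in (base_model_name, with_head_name):
    print("/".join(name.split("/")[1:]))
# Both print: bert/embeddings/word_embeddings/weight:0
# Splitting with [2:] would also drop the shared "bert" prefix; [1:] keeps it,
# and the names still line up across base models and models with heads.
```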

-                symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
-                symbolic_weights_names = set(
-                    "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
-                )
-                missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
-                unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
+            name = layer.name
+            model_layers_name_value[name] = layer

-    return missing_layers, unexpected_layers
+        model_layers_name = set(model_layers_name_value.keys())
+        renamed_saved_h5_model_layers_names = set()

+        for layer_name in saved_h5_model_layers_name:
+            name = layer_name

-def load_tf_weights(model, resolved_archive_file):
-    """
-    Load the TF weights from a H5 file.
+            renamed_saved_h5_model_layers_names.add(name)
Collaborator: I don't get this code. It loops through the set `saved_h5_model_layers_name` and adds each element to a new empty set, so this is exactly the same as doing `renamed_saved_h5_model_layers_names = saved_h5_model_layers_name.copy()`. But why do we need a new variable, and why is it "renamed"?


-    Args:
-        model (:obj:`tf.keras.models.Model`):
-            The model to load the weights into.
-        resolved_archive_file (:obj:`str`):
-            The location of the H5 file.
-    """
-    with h5py.File(resolved_archive_file, "r") as f:
-        saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
+        missing_layers = list(model_layers_name - renamed_saved_h5_model_layers_names)
+        unexpected_layers = list(renamed_saved_h5_model_layers_names - model_layers_name)
+        saved_weight_names_set = set()
+        symbolic_weights_names = set()
+        weight_value_tuples = []

-        for layer in model.layers:
-            if layer.name in saved_layer_names:
-                g = f[layer.name]
+        for layer_name in saved_h5_model_layers_name:
Collaborator: Any reason to change the loop? AFAICT we only do things for objects that are in the intersection of `saved_h5_model_layers_name` (previously `saved_layer_names`) and `model_layers_name`. However, the previous loop did not require the dictionary `model_layers_name_value`, and this dictionary is only used here, so it seems we only gain complexity with this change.

Contributor Author: You are right, this variable was not really necessary and the previous loop could be kept as it was. I changed it.
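Both loop shapes visit exactly the same layers, which is the reviewer's point (a sketch using the diff's variable names; `process` is a hypothetical stand-in for the loading body):

```python
# Loop in this revision: iterate checkpoint layer names, keep the model hits.
for layer_name in saved_h5_model_layers_name:
    if layer_name in model_layers_name:
        process(layer_name)

# Previous loop: iterate model layers, keep the checkpoint hits.
for layer in model.layers:
    if layer.name in saved_h5_model_layers_name:
        process(layer.name)

# Either way, what gets processed is the intersection of the two name sets:
for layer_name in saved_h5_model_layers_name & model_layers_name:
    process(layer_name)
```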

+            if layer_name in model_layers_name:
+                g = f[layer_name]
Contributor: What is `g`?

Contributor Author: Better comment?
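For readers outside the thread: `g` is the `h5py` group holding everything saved under one layer of the checkpoint. A minimal sketch (the file path and layer name are hypothetical):

```python
import h5py
from tensorflow.python.keras.saving import hdf5_format

with h5py.File("tf_model.h5", "r") as f:
    g = f["bert"]  # HDF5 group for the layer named "bert"
    # Same helper the diff uses; it also handles attributes split into chunks.
    weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
    print(weight_names[:3])
```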

                 saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
+                layer = model_layers_name_value[layer_name]
                 symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
                 saved_weight_names_values = {}
Contributor: It's a nit, but I think it's cleaner to name this dict just `saved_weights`. `saved_weight_names_values` makes me think of a list of values, not a key/value dict...

Contributor Author: Updated!


                 for weight_name in saved_weight_names:
                     name = "/".join(weight_name.split("/")[1:])
                     saved_weight_names_values[name] = np.asarray(g[weight_name])

+                    saved_weight_names_set.add(name)

                 for symbolic_weight in symbolic_weights:
-                    splited_layers = symbolic_weight.name.split("/")[1:]
-                    symbolic_weight_name = "/".join(splited_layers)
+                    symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])
+                    saved_weight_value = None

                     if symbolic_weight_name in saved_weight_names_values:
                         saved_weight_value = saved_weight_names_values[symbolic_weight_name]

+                    if saved_weight_value is not None:
Collaborator: Why add this test and the `saved_weight_value = None` line before it? There is no other place `saved_weight_value` is used, so this just adds extra complexity to the code.

Contributor: Agree here. I think we can just put all the code below `if saved_weight_value is not None:` directly under `if symbolic_weight_name in saved_weight_names_values:`, no?

Contributor: Alternatively we could do `saved_weight_value = saved_weight_names_values.get(symbolic_weight_name, None)`

Contributor Author: Like this?
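The three shapes under discussion, side by side (a sketch; `saved_weight_names_values` maps cleaned weight names to numpy arrays, and `...` stands for the shape check and collection below):

```python
# As written in this revision: a sentinel plus a membership test.
saved_weight_value = None
if symbolic_weight_name in saved_weight_names_values:
    saved_weight_value = saved_weight_names_values[symbolic_weight_name]

if saved_weight_value is not None:
    ...

# Equivalent with dict.get, as suggested:
saved_weight_value = saved_weight_names_values.get(symbolic_weight_name, None)

if saved_weight_value is not None:
    ...

# Or drop the sentinel entirely and nest under the membership test:
if symbolic_weight_name in saved_weight_names_values:
    saved_weight_value = saved_weight_names_values[symbolic_weight_name]
    ...
```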

                         if K.int_shape(symbolic_weight) != saved_weight_value.shape:
                             try:
                                 array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
@@ -318,8 +307,15 @@ def load_tf_weights(model, resolved_archive_file):

                         weight_value_tuples.append((symbolic_weight, array))

+                    symbolic_weights_names.add(symbolic_weight_name)
Member: This will add the `symbolic_weight_name` to `symbolic_weights_names` even if `saved_weight_value` is None. Is that expected?

Contributor Author: Yes, this is expected. I need to have the name of all the layers from the checkpoint on one side, and from the instantiated model on the other side, in order to properly compute the missing and unexpected layers.
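In other words, both name sets are collected unconditionally so that the missing and unexpected layers fall out as plain set differences at the end. A self-contained toy example:

```python
# Toy weight names: from the instantiated model vs. from the checkpoint.
symbolic_weights_names = {"dense/kernel:0", "dense/bias:0", "classifier/kernel:0"}
saved_weight_names_set = {"dense/kernel:0", "dense/bias:0", "pooler/kernel:0"}

missing_layers = list(symbolic_weights_names - saved_weight_names_set)
unexpected_layers = list(saved_weight_names_set - symbolic_weights_names)
print(missing_layers)     # ['classifier/kernel:0']: in the model, not the checkpoint
print(unexpected_layers)  # ['pooler/kernel:0']: in the checkpoint, not the model
```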


         K.batch_set_value(weight_value_tuples)

+    missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
+    unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

+    return missing_layers, unexpected_layers


 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
@@ -728,7 +724,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
             # 'by_name' allow us to do transfer learning by skipping/adding layers
             # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
             try:
-                load_tf_weights(model, resolved_archive_file)
+                missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
             except OSError:
                 raise OSError(
                     "Unable to load weights from h5 file. "
@@ -737,8 +733,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         model(model.dummy_inputs, training=False)  # Make sure restore ops are run

-        missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

         if cls.authorized_missing_keys is not None:
             for pat in cls.authorized_missing_keys:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
@@ -1034,18 +1028,18 @@ def call(self, inputs, cls_index=None, training=False):
         return output


-def shape_list(x: tf.Tensor) -> List[int]:
+def shape_list(tensor: tf.Tensor) -> List[int]:
     """
     Deal with dynamic shape in tensorflow cleanly.

     Args:
-        x (:obj:`tf.Tensor`): The tensor we want the shape of.
+        tensor (:obj:`tf.Tensor`): The tensor we want the shape of.

     Returns:
         :obj:`List[int]`: The shape of the tensor as a list.
     """
-    static = x.shape.as_list()
-    dynamic = tf.shape(x)
+    static = tensor.shape.as_list()
+    dynamic = tf.shape(tensor)
     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
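The rename from `x` to `tensor` is cosmetic, but the helper is worth illustrating: inside a `tf.function`, some dimensions are only known at run time, and `shape_list` returns static Python ints where possible and dynamic scalars otherwise. A runnable sketch (the `flatten_pairs` function is made up for the example):

```python
import tensorflow as tf
from typing import List

def shape_list(tensor: tf.Tensor) -> List[int]:
    # Same body as in the diff above.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

@tf.function(input_signature=[tf.TensorSpec([None, 16], tf.float32)])
def flatten_pairs(x):
    batch, hidden = shape_list(x)  # batch is a dynamic scalar, hidden == 16
    # Mixing the dynamic batch dimension with static ints works transparently:
    return tf.reshape(x, (batch * hidden // 32, 32))

print(flatten_pairs(tf.zeros((4, 16))).shape)  # (2, 32)
```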

