diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index 2de2b1f0eecb..324f350f85db 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -236,9 +236,9 @@ def compute_loss(self, labels, logits): return loss_fn(next_sentence_label, next_sentence_reduced_logits) -def detect_tf_missing_unexpected_layers(model, resolved_archive_file): +def load_tf_weights(model, resolved_archive_file): """ - Detect missing and unexpected layers. + Detect missing and unexpected layers and load the TF weights according to their names and shapes. Args: model (:obj:`tf.keras.models.Model`): @@ -252,62 +252,60 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file): missing_layers = [] unexpected_layers = [] + # Read the H5 file with h5py.File(resolved_archive_file, "r") as f: - saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - model_layer_names = set(layer.name for layer in model.layers) - missing_layers = list(model_layer_names - saved_layer_names) - unexpected_layers = list(saved_layer_names - model_layer_names) - - for layer in model.layers: - if layer.name in saved_layer_names: - g = f[layer.name] - saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") - saved_weight_names_set = set( - "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names - ) - symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - symbolic_weights_names = set( - "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights - ) - missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) - unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + # Retrieve the name of each layer from the H5 file + saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) - return missing_layers, unexpected_layers - - -def 
load_tf_weights(model, resolved_archive_file): - """ - Load the TF weights from a H5 file. + # Find the missing layers from the high level list of layers + missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name) - Args: - model (:obj:`tf.keras.models.Model`): - The model to load the weights into. - resolved_archive_file (:obj:`str`): - The location of the H5 file. - """ - with h5py.File(resolved_archive_file, "r") as f: - saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) + # Find the unexpected layers from the high level list of layers + unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers])) + saved_weight_names_set = set() + symbolic_weights_names = set() weight_value_tuples = [] + # Compute missing and unexpected sub layers + # Store the weights in a list of tuples that looks like [(weight_object, value_of_weight),...] for layer in model.layers: - if layer.name in saved_layer_names: - g = f[layer.name] - saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") + # If the name of this model layer is among the layer names saved in the H5 file + if layer.name in saved_h5_model_layers_name: + # Get the H5 layer object from its name + h5_layer_object = f[layer.name] + # Get all the weights as a list from the layer object symbolic_weights = layer.trainable_weights + layer.non_trainable_weights - saved_weight_names_values = {} + saved_weights = {} - for weight_name in saved_weight_names: + # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} + # And a set with only the names + for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): + # TF names always start with the model name so we ignore it name = "/".join(weight_name.split("/")[1:]) - saved_weight_names_values[name] = np.asarray(g[weight_name]) + saved_weights[name] = 
np.asarray(h5_layer_object[weight_name]) + # Add the updated name to the final set for computing missing/unexpected values + saved_weight_names_set.add(name) + + # Loop over each weight from the instantiated model and compare with the weights from the H5 file for symbolic_weight in symbolic_weights: - splited_layers = symbolic_weight.name.split("/")[1:] - symbolic_weight_name = "/".join(splited_layers) + # TF names always start with the model name so we ignore it + symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) + + # Here we check if the current weight is among the weights from the H5 file + # If yes, get the weight_value of the corresponding weight from the H5 file + # If not, set the value to None + saved_weight_value = saved_weights.get(symbolic_weight_name, None) - if symbolic_weight_name in saved_weight_names_values: - saved_weight_value = saved_weight_names_values[symbolic_weight_name] + # Add the updated name to the final set for computing missing/unexpected values + symbolic_weights_names.add(symbolic_weight_name) + # If the current weight is found + if saved_weight_value is not None: + # Check if the shape of the current weight and the one from the H5 file are different if K.int_shape(symbolic_weight) != saved_weight_value.shape: + # If yes we reshape the weight from the H5 file according to the shape of the current weight + # If the two shapes are not compatible we raise an error try: array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) except AssertionError as e: @@ -316,10 +314,18 @@ def load_tf_weights(model, resolved_archive_file): else: array = saved_weight_value + # We create the tuple that will be loaded and add it to the final list weight_value_tuples.append((symbolic_weight, array)) + # Load all the weights K.batch_set_value(weight_value_tuples) + # Compute the missing and unexpected layers + missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) + 
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) + + return missing_layers, unexpected_layers + class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): r""" @@ -728,7 +734,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): # 'by_name' allow us to do transfer learning by skipping/adding layers # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 try: - load_tf_weights(model, resolved_archive_file) + missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file) except OSError: raise OSError( "Unable to load weights from h5 file. " @@ -737,8 +743,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): model(model.dummy_inputs, training=False) # Make sure restore ops are run - missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file) - if cls.authorized_missing_keys is not None: for pat in cls.authorized_missing_keys: missing_keys = [k for k in missing_keys if re.search(pat, k) is None] @@ -1034,18 +1038,18 @@ def call(self, inputs, cls_index=None, training=False): return output -def shape_list(x: tf.Tensor) -> List[int]: +def shape_list(tensor: tf.Tensor) -> List[int]: """ Deal with dynamic shape in tensorflow cleanly. Args: - x (:obj:`tf.Tensor`): The tensor we want the shape of. + tensor (:obj:`tf.Tensor`): The tensor we want the shape of. Returns: :obj:`List[int]`: The shape of the tensor as a list. """ - static = x.shape.as_list() - dynamic = tf.shape(x) + static = tensor.shape.as_list() + dynamic = tf.shape(tensor) return [dynamic[i] if s is None else s for i, s in enumerate(static)]