Skip to content

Commit

Permalink
New TF loading weights (#8490)
Browse files Browse the repository at this point in the history
* New TF loading weights

* apply style

* Better naming

* Largely comment the loading method

* Apply style

* Address Patrick's comments

* Remove useless line of code

* Update Docstring

* Address Sylvain's and Lysandre's comments

* Simplify the names computation

* Typos
  • Loading branch information
jplu authored Nov 18, 2020
1 parent 0df91ee commit 3bc1540
Showing 1 changed file with 55 additions and 51 deletions.
106 changes: 55 additions & 51 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ def compute_loss(self, labels, logits):
return loss_fn(next_sentence_label, next_sentence_reduced_logits)


def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
def load_tf_weights(model, resolved_archive_file):
"""
Detect missing and unexpected layers.
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
Args:
model (:obj:`tf.keras.models.Model`):
Expand All @@ -252,62 +252,60 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
missing_layers = []
unexpected_layers = []

# Read the H5 file
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
model_layer_names = set(layer.name for layer in model.layers)
missing_layers = list(model_layer_names - saved_layer_names)
unexpected_layers = list(saved_layer_names - model_layer_names)

for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
saved_weight_names_set = set(
"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
)
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
symbolic_weights_names = set(
"/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
)
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

return missing_layers, unexpected_layers


def load_tf_weights(model, resolved_archive_file):
"""
Load the TF weights from a H5 file.
# Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)

Args:
model (:obj:`tf.keras.models.Model`):
The model to load the weights into.
resolved_archive_file (:obj:`str`):
The location of the H5 file.
"""
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
# Find the unexpected layers from the high level list of layers
unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
saved_weight_names_set = set()
symbolic_weights_names = set()
weight_value_tuples = []

# Compute missing and unexpected sub layers
# Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...]
for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
# if layer_name from the H5 file belongs to the layers from the instantiated model
if layer.name in saved_h5_model_layers_name:
# Get the H5 layer object from its name
h5_layer_object = f[layer.name]
# Get all the weights as a list from the layer object
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
saved_weight_names_values = {}
saved_weights = {}

for weight_name in saved_weight_names:
# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
# And a set with only the names
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
# TF names always start with the model name so we ignore it
name = "/".join(weight_name.split("/")[1:])
saved_weight_names_values[name] = np.asarray(g[weight_name])
saved_weights[name] = np.asarray(h5_layer_object[weight_name])

# Add the updated name to the final list for computing missing/unexpected values
saved_weight_names_set.add(name)

# Loop over each weights from the instantiated model and compare with the weights from the H5 file
for symbolic_weight in symbolic_weights:
splited_layers = symbolic_weight.name.split("/")[1:]
symbolic_weight_name = "/".join(splited_layers)
# TF names always start with the model name so we ignore it
symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

# here we check if the current weight is among the weights from the H5 file
# If yes, get the weight_value of the corresponding weight from the H5 file
# If not, make the value to None
saved_weight_value = saved_weights.get(symbolic_weight_name, None)

if symbolic_weight_name in saved_weight_names_values:
saved_weight_value = saved_weight_names_values[symbolic_weight_name]
# Add the updated name to the final list for computing missing/unexpected values
symbolic_weights_names.add(symbolic_weight_name)

# If the current weight is found
if saved_weight_value is not None:
# Check if the shape of the current weight and the one from the H5 file are different
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
# If yes we reshape the weight from the H5 file accordingly to the current weight
# If the two shapes are not compatible we raise an issue
try:
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
except AssertionError as e:
Expand All @@ -316,10 +314,18 @@ def load_tf_weights(model, resolved_archive_file):
else:
array = saved_weight_value

# We create the tuple that will be loaded and add it to the final list
weight_value_tuples.append((symbolic_weight, array))

# Load all the weights
K.batch_set_value(weight_value_tuples)

# Compute the missing and unexpected layers
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

return missing_layers, unexpected_layers


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
r"""
Expand Down Expand Up @@ -727,7 +733,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# 'by_name' allow us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
load_tf_weights(model, resolved_archive_file)
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
except OSError:
raise OSError(
"Unable to load weights from h5 file. "
Expand All @@ -736,8 +742,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

model(model.dummy_inputs, training=False) # Make sure restore ops are run

missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
Expand Down Expand Up @@ -1033,18 +1037,18 @@ def call(self, inputs, cls_index=None, training=False):
return output


def shape_list(x: tf.Tensor) -> List[int]:
def shape_list(tensor: tf.Tensor) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.
Args:
x (:obj:`tf.Tensor`): The tensor we want the shape of.
tensor (:obj:`tf.Tensor`): The tensor we want the shape of.
Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
static = x.shape.as_list()
dynamic = tf.shape(x)
static = tensor.shape.as_list()
dynamic = tf.shape(tensor)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]


Expand Down

0 comments on commit 3bc1540

Please sign in to comment.