
New TF embeddings (cleaner and faster) #9418

Merged: 24 commits merged into huggingface:master on Jan 20, 2021
Conversation

@jplu (Contributor) commented Jan 5, 2021

What does this PR do?

This PR proposes a better implementation of the embedding layer for the BERT-like TF models. Another benefit of this cleanup is better computational performance:

import cProfile
from transformers import TFBertForMaskedLM

model = TFBertForMaskedLM.from_pretrained("bert-base-cased")
cProfile.run("model(model.dummy_inputs)")

# current master
56150 function calls (55318 primitive calls) in 0.096 seconds

# with new embeddings implem
55732 function calls (54891 primitive calls) in 0.080 seconds

This new implementation should be compatible with the upcoming rework of the resizing proposed in #9193. Similar work will be applied to TFSharedEmbeddings in a follow-up PR.

All slow and quick tests pass.

EDIT: I don't know why GitHub has issues pinning the reviewers, so pinging @LysandreJik, @sgugger, and @patrickvonplaten.

@jplu jplu requested review from LysandreJik, sgugger and patrickvonplaten and removed request for LysandreJik, sgugger and patrickvonplaten January 5, 2021 13:27

super().build(input_shape=input_shape)

def get_config(self):
Contributor:

What is this function needed for?

@jplu (PR author) replied Jan 5, 2021:

This function is required for a layer that takes parameters in its __init__ to become serializable; see more detail in the docs: https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#get_config and https://www.tensorflow.org/guide/keras/custom_layers_and_models#you_can_optionally_enable_serialization_on_your_layers
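
For illustration, a minimal sketch of what such a get_config can look like on a custom layer (hypothetical layer and parameters, not the exact code from this PR):

import tensorflow as tf

class WordEmbeddings(tf.keras.layers.Layer):
    def __init__(self, vocab_size, hidden_size, initializer_range=0.02, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range

    def get_config(self):
        # Return the constructor arguments so Keras can re-create the layer
        # when deserializing a saved model.
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))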

Contributor (PR author):

Basically, this is what the @keras_serializable decorator does.


super().build(input_shape)

def get_config(self):
Contributor:

Same here: why do we need this function?


super().build(input_shape=input_shape)

def get_config(self):
Contributor:

Why is the function required?

name="token_type_embeddings",
)
self.embeddings = tf.keras.layers.Add()
Contributor:

Do we need this? This layer is just an "add" operation, no? Why is it called embeddings?

@jplu (PR author) replied Jan 5, 2021:

This is the optimized version of doing tensor + tensor + ...; the other advantage of using this layer (besides computational performance) is that it handles some checks over the given tensors, such as a proper shape.

I named it embeddings because it represents the addition of all the embeddings.
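
For illustration, a minimal sketch of the two ways of summing the embeddings (shapes are made up):

import tensorflow as tf

word = tf.random.normal((2, 8, 768))        # word embeddings
position = tf.random.normal((2, 8, 768))    # position embeddings
token_type = tf.random.normal((2, 8, 768))  # token type embeddings

# Plain tensor arithmetic
summed = word + position + token_type

# tf.keras.layers.Add: same result, but the layer also checks that
# all inputs have compatible shapes before adding them.
embeddings_sum = tf.keras.layers.Add()([word, position, token_type])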

Collaborator:

The name could be clearer I think: embeddings_sum is more explicit.

@@ -501,10 +448,10 @@ def __init__(self, config, add_pooling_layer=True, **kwargs):
)

def get_input_embeddings(self):
return self.embeddings
Contributor:

Does this still return the same type?

Contributor (PR author):

Yes! Still a tf.keras.layers.Layer object.
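
A quick, purely illustrative way to check that claim:

from transformers import TFBertModel
import tensorflow as tf

model = TFBertModel.from_pretrained("bert-base-cased")
# The returned object is a layer both on master and on this branch.
print(isinstance(model.get_input_embeddings(), tf.keras.layers.Layer))  # True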

Collaborator:

I don't understand how it can be used above (line 420) in a tf.matmul if it's a layer and not a weight.

@patrickvonplaten (Contributor):

I like this PR in general!

Just wondering about two things:

  1. Do we need this get_config function?
  2. Not a huge fan of the Add() keras layer...does this really improve performance much?

@LysandreJik (Member) left a comment:

From a quick look, you're factoring the embeddings computation into three classes that will live in modeling_tf_utils.py.

Usually we try to be as explicit as possible and display every operation in a single file, while here we're moving the different embedding operations to another file. I think this goes against our "everything in one file" principle.

Is there a good reason for the embeddings to be the exception to this rule? Personally, I think I would like to see directly in the file that the embeddings are computed differently according to the matrix sizes, but putting these layers in modeling_tf_utils.py makes it abstracted/hidden.

@jplu (PR author) commented Jan 5, 2021:

Good point @LysandreJik! Basically, most of the models here share a similar embedding computation that stays inside their respective model files. What has been exported is just the common part of the computation, which means that WordEmbeddings, PositionalEmbeddings, and TokenTypeEmbeddings are always the same, no matter which model uses them.

This is the same logic that is currently applied to TFSharedEmbeddings.

@sgugger (Collaborator) left a comment:

I just reviewed the general approach on one model for now and have some questions before going further. If I understand correctly, the computation of the three different types of embeddings is split in three different ways to maximize the speedup, but I wonder whether this is documented by TF or just based on tests on one particular setup. Before adding the extra complexity, I would like to be sure it brings a speedup in almost all possible environments (CPU, GPU, multi-GPU, TPU) without any loss in memory footprint (one-hot encoding the token type ids seems harmless, but we never know).

As for putting those in modeling utils versus the model file, I agree with Lysandre that this breaks our philosophy of putting everything in each model file. I expressed the same reservations for TFSharedEmbeddings when it was introduced.

name="token_type_embeddings",
)
self.embeddings = tf.keras.layers.Add()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name could be clearer I think: embeddings_sum is more explicit.

@@ -501,10 +448,10 @@ def __init__(self, config, add_pooling_layer=True, **kwargs):
)

def get_input_embeddings(self):
return self.embeddings
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand how it can be used above (line 420) in a tf.matmul if it's a layer and not a weight.

@jplu (PR author) commented Jan 5, 2021:

I just reviewed the general approach on one model for now and have some questions before going further. If I understand correctly, the computation of the three different types of embeddings is split in three different ways to maximize the speedup, but I wonder whether this is documented by TF or just based on tests on one particular setup. Before adding the extra complexity, I would like to be sure it brings a speedup in almost all possible environments (CPU, GPU, multi-GPU, TPU) without any loss in memory footprint (one-hot encoding the token type ids seems harmless, but we never know).

I basically took the official implementation of the Transformer encoder available in the Google repo https://github.com/tensorflow/models/tree/master/official/nlp/keras_nlp as an example. After several experiments (only on CPU and GPU though), I ended up extracting an optimal version of each embedding from it.

As for putting those in modeling utils versus the model file, I agree with Lysandre that this breaks our philosophy of putting everything in each model file. I expressed the same reservations for TFSharedEmbeddings when it was introduced.

I don't mind copy/pasting the same layers in all the concerned files if that is the recommended way. @sgugger @LysandreJik, would you be more confident if I created a version for each model and added the comment # Copied from ... every time it is a strict copy/paste?

I don't understand how it can be used above (line 420) in a tf.matmul if it's a layer and not a weight.

Now get_input_embeddings returns a WordEmbeddings layer that has a word_embeddings attribute. If you look at the BERT model, for example, the TFBertLMPredictionHead layer takes a WordEmbeddings layer as input_embeddings and uses the WordEmbeddings.word_embeddings attribute in the tf.matmul.
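
For illustration, a minimal sketch of that pattern (hypothetical class and shapes, not the exact code from the PR):

import tensorflow as tf

class LMPredictionHead(tf.keras.layers.Layer):
    def __init__(self, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        # The head receives the whole embedding layer, not a bare weight.
        self.input_embeddings = input_embeddings  # a WordEmbeddings layer

    def call(self, hidden_states):
        # hidden_states: [batch_size, seq_length, hidden_size]
        # input_embeddings.word_embeddings: [vocab_size, hidden_size]
        return tf.matmul(
            hidden_states, self.input_embeddings.word_embeddings, transpose_b=True
        )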

@sgugger (Collaborator) commented Jan 7, 2021:

Now get_input_embeddings returns a WordEmbeddings layer that has a word_embeddings attribute. If you look at the BERT model, for example, the TFBertLMPredictionHead layer takes a WordEmbeddings layer as input_embeddings and uses the WordEmbeddings.word_embeddings attribute in the tf.matmul.

So this part confuses me. Why name the weights inside WordEmbeddings word_embeddings? It causes so much headache when reading the code afterward, as we keep seeing word_embeddings attributes that might be either an embedding layer or a weight.

Also, how does the new organization not screw up pretrained weights? From what I understand, the old word_embeddings in the BertEmbeddings layer used to be a weight, and now it's a layer with an added word_embeddings attribute?

@jplu (PR author) commented Jan 8, 2021:

So this part confuses me. Why name the weights inside WordEmbeddings word_embeddings? It causes so much headache when reading the code afterward, as we keep seeing word_embeddings attributes that might be either an embedding layer or a weight.

I agree it is confusing. If you prefer, it can be called weight, as in TFSharedEmbeddings; I think that would be a more suitable name. This renaming would simplify this kind of check (from the upcoming PR on embedding resizing):

def _get_word_embedding_weight(self, embedding_layer):
    if hasattr(embedding_layer, "word_embeddings"):
        return embedding_layer.word_embeddings
    elif hasattr(embedding_layer, "weight"):
        return embedding_layer.weight
    elif hasattr(embedding_layer, "decoder"):
        return embedding_layer.decoder
    else:
        # Build the word embedding weights if they don't exist yet,
        # then retry to get the attribute once built.
        self(self.dummy_inputs)
        if hasattr(embedding_layer, "word_embeddings"):
            return embedding_layer.word_embeddings
        elif hasattr(embedding_layer, "weight"):
            return embedding_layer.weight
        elif hasattr(embedding_layer, "decoder"):
            return embedding_layer.decoder
        else:
            return None

No more word_embeddings or weight, only weight. What do you think?

Also, how does the new organization not screw up pretrained weights? From what I understand, the old word_embeddings in the BertEmbeddings layer used to be a weight, and now it's a layer with an added word_embeddings attribute?

This is because before we were using a name scope, and we are not anymore in this PR. Let's say that defining a name scope or creating a layer amounts to the same thing for naming. In both cases the weight is named 'tf_bert_model/bert/embeddings/word_embeddings/weight:0': until now, the word_embeddings part of the name came from the embeddings being created inside tf.name_scope("word_embeddings"):, while in this PR the weight gets the same name because of the name of the new WordEmbeddings layer.
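
For illustration, a minimal sketch of the two ways of obtaining that variable name, under the assumption that the weight is created in build in both cases (class names and shapes are made up):

import tensorflow as tf

# Before: the weight is created under an explicit name scope, so its full
# name ends in ".../word_embeddings/weight:0".
class OldStyleEmbeddings(tf.keras.layers.Layer):
    def build(self, input_shape):
        with tf.name_scope("word_embeddings"):
            self.word_embeddings = self.add_weight(name="weight", shape=[30522, 768])
        super().build(input_shape)

# After: the weight lives in a sub-layer named "word_embeddings", which
# contributes the same "word_embeddings" segment to the variable name.
class WordEmbeddings(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.weight = self.add_weight(name="weight", shape=[30522, 768])
        super().build(input_shape)

    def call(self, input_ids):
        return tf.gather(self.weight, input_ids)

class NewStyleEmbeddings(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.word_embeddings = WordEmbeddings(name="word_embeddings")

    def call(self, input_ids):
        return self.word_embeddings(input_ids)

layer = NewStyleEmbeddings(name="embeddings")
_ = layer(tf.constant([[1, 2, 3]]))
print(layer.word_embeddings.weight.name)  # e.g. "embeddings/word_embeddings/weight:0"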

@sgugger (Collaborator) commented Jan 8, 2021:

Yes, having only "weight" makes more sense to me, and it would make the code easier to read. Thanks for explaining why the name of the weight doesn't change for loading!

@jplu (PR author) commented Jan 11, 2021:

I found another advantage of this new embedding computation: it allows our models to be compiled with XLA_GPU and XLA_TPU, which was not the case before. Small proof of concept on a machine with a GPU:

from transformers import TFBertModel
import tensorflow as tf

model = TFBertModel.from_pretrained("bert-base-cased")

@tf.function(experimental_compile=True)
def run():
    return model(model.dummy_inputs)

outputs = run()

On master, this fails with:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Trying to access resource _AnonymousVar4 located in device /job:localhost/replica:0/task:0/device:CPU:0 from device /job:localhost/replica:0/task:0/device:GPU:0 [Op:__inference_run_4637]

With this PR, it works as expected. The reason is that the tf.keras.layers.Embedding layers are initialized when the model is instantiated instead of being initialized at build time.

@jplu jplu force-pushed the new-tf-embeddings branch from 7a1c5b1 to c1cb284 on January 11, 2021 22:41
@jplu jplu force-pushed the new-tf-embeddings branch from f5981f0 to 0eeac1e on January 12, 2021 10:31
@jplu (PR author) commented Jan 12, 2021:

Now, each model has its own WordEmbedding, TokenTypeEmbeddings, and PositionEmbedding layer in the model file, decorated with the comment # Copied from ..., and the word_embeddings weights have been renamed to weight to make it more understandable and aligned with the name in TFSharedEmbeddings.

@sgugger (Collaborator) left a comment:

Thanks for the modifications! This looks way better now, I think.

if embeds is not None:
return embeds

model(model.dummy_inputs)
Collaborator:

Add a comment here to say we retry after building the model, in case it was not built already?

@@ -118,7 +234,7 @@ def call(self, hidden_states, attention_mask=None, head_mask=None, output_attent
attention_scores = tf.einsum("aecd,abcd->acbe", key_layer, query_layer)

if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in TFElectraModel call() function)
# Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
Collaborator:

This should not be replaced, see comment above.

@@ -536,96 +652,41 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):

Returns: tf.Tensor
"""
seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
bsz, seq_length = shape_list(tensor=inputs_embeds)[:2]
Collaborator:

bsz is a bit too short IMO, batch_size should be used (here and two lines below).

def call(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
Collaborator:

I understand MPNet's implementation has some token_type_ids it doesn't use, but I'd leave them here for now until there is a general fix (one that also deals with the PyTorch implementation). The tokenizer still returns those token_type_ids, so this would cause problems if a user feeds the output of a tokenizer to one of those models.

Contributor (PR author):

token_type_ids is not in the PyTorch implementation, so I think the tokenizer should be fixed at the same time as the TF model.

@@ -132,96 +249,41 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds):

Returns: tf.Tensor
"""
seq_length = shape_list(inputs_embeds)[1]
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
bsz, seq_length = shape_list(tensor=inputs_embeds)[:2]
Collaborator:

Same as before: bsz -> batch_size

Comment on lines 810 to 832
embeds = getattr(embedding_layer, "weight", None)

if embeds is not None:
return embeds

embeds = getattr(embedding_layer, "decoder", None)

if embeds is not None:
return embeds

model(model.dummy_inputs)

embeds = getattr(embedding_layer, "weight", None)

if embeds is not None:
return embeds

embeds = getattr(embedding_layer, "decoder", None)

if embeds is not None:
return embeds

return None
Collaborator:

Same comments as for modeling_utils (plus what are we testing if we just use the same code?)

@patrickvonplaten (Contributor) left a comment:

LGTM in general. One thing I'm not 100% sure about is whether we really need to add Keras layers like tf.keras.layers.Add(). If we start doing this for the embeddings now, I'm wondering whether we should do the same for all residual connections in the self-attention blocks.

@jplu (PR author) commented Jan 19, 2021:

LGTM in general. One thing I'm not 100% sure about is whether we really need to add Keras layers like tf.keras.layers.Add(). If we start doing this for the embeddings now, I'm wondering whether we should do the same for all residual connections in the self-attention blocks.

In absolute terms, yes, we should. In an ideal world, every time TF provides a function or layer for doing something, we should use it, as that is part of the optimization process. I know and understand that it might seem confusing and start to diverge from what the PT code looks like.

@LysandreJik (Member) left a comment:

Yes, this LGTM. I also agree with your explanations regarding the Add layers.

@jplu jplu merged commit 14042d5 into huggingface:master Jan 20, 2021
@jplu jplu deleted the new-tf-embeddings branch January 20, 2021 11:21
"""
if mode == "embedding":
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
Contributor:

This was a single matrix multiplication before, no?

Contributor (PR author):

Yes

return dict(list(base_config.items()) + list(config.items()))

def call(self, input_ids):
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
Contributor:

Those are multiple operations that replaced a single matrix operation, no?

Contributor (PR author):

Yes.
