New TF embeddings (cleaner and faster) (#9418)
* Create new embeddings + add to BERT

* Add Albert

* Add DistilBert

* Add Albert + Electra + Funnel

* Add Longformer + Lxmert

* Add last models

* Apply style

* Update the template

* Remove unused imports

* Rename attribute

* Import embeddings in their own model file

* Replace word_embeddings by weight

* fix naming

* Fix Albert

* Fix Albert

* Fix Longformer

* Fix Lxmert Mobilebert and MPNet

* Fix copy

* Fix template

* Update the get weights function

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <[email protected]>

* Update src/transformers/models/electra/modeling_tf_electra.py

Co-authored-by: Sylvain Gugger <[email protected]>

* address Sylvain's comments

Co-authored-by: Sylvain Gugger <[email protected]>
jplu and sgugger authored Jan 20, 2021
1 parent 12f0d7e commit 14042d5
Showing 13 changed files with 1,829 additions and 1,188 deletions.
155 changes: 136 additions & 19 deletions src/transformers/modeling_tf_utils.py
@@ -809,25 +809,29 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:

        return model_embeds

-    def _get_word_embedding_weight(self, embedding_layer):
-        if hasattr(embedding_layer, "word_embeddings"):
-            return embedding_layer.word_embeddings
-        elif hasattr(embedding_layer, "weight"):
-            return embedding_layer.weight
-        elif hasattr(embedding_layer, "decoder"):
-            return embedding_layer.decoder
-        else:
-            # Here we build the word embeddings weights if not exists.
-            # And then we retry to get the attribute once built.
-            self(self.dummy_inputs)
-            if hasattr(embedding_layer, "word_embeddings"):
-                return embedding_layer.word_embeddings
-            elif hasattr(embedding_layer, "weight"):
-                return embedding_layer.weight
-            elif hasattr(embedding_layer, "decoder"):
-                return embedding_layer.decoder
-            else:
-                return None
+    def _get_word_embedding_weight(model, embedding_layer):
+        embeds = getattr(embedding_layer, "weight", None)
+        if embeds is not None:
+            return embeds
+
+        embeds = getattr(embedding_layer, "decoder", None)
+        if embeds is not None:
+            return embeds
+
+        # The reason why the attributes don't exist might be
+        # because the model is not built, so retry getting
+        # the argument after building the model
+        model(model.dummy_inputs)
+
+        embeds = getattr(embedding_layer, "weight", None)
+        if embeds is not None:
+            return embeds
+
+        embeds = getattr(embedding_layer, "decoder", None)
+        if embeds is not None:
+            return embeds
+
+        return None
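
Illustrative sketch, not part of the commit: the rewritten helper relies on the fact that Keras layers only create their variables once they are built. If neither `weight` nor `decoder` exists yet, it runs the dummy inputs through the model and looks again. The `ToyEmbeddings` layer below is invented purely to show that lazy-build behaviour in isolation.

import tensorflow as tf

class ToyEmbeddings(tf.keras.layers.Layer):
    """Toy layer: its variable only exists after the layer has been built."""

    def build(self, input_shape):
        self.weight = self.add_weight(name="weight", shape=[10, 4])
        super().build(input_shape)

    def call(self, input_ids):
        return tf.gather(params=self.weight, indices=input_ids)

layer = ToyEmbeddings()
print(getattr(layer, "weight", None))   # None: the variable does not exist yet
layer(tf.constant([[1, 2, 3]]))         # first call builds the layer and creates the weight
print(getattr(layer, "weight", None))   # <tf.Variable ... shape=(10, 4) ...>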

    def _resize_token_embeddings(self, new_num_tokens):
        old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
@@ -1319,6 +1323,119 @@ def call(self, x):
        return x


class WordEmbeddings(tf.keras.layers.Layer):
    def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.word_embeddings = self.add_weight(
            name="weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=get_initializer(initializer_range=self.initializer_range),
        )

        super().build(input_shape=input_shape)

    def get_config(self):
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, input_ids):
        flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
        embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids)
        embeddings = tf.reshape(
            tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
        )

        embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])

        return embeddings
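
Illustrative usage sketch of the class above, not part of the commit; the hyperparameters are assumed (BERT-base-like) and the code runs in the context of this module. The lookup is a plain `tf.gather` over a `(vocab_size, hidden_size)` table, reshaped back to the input shape plus the hidden dimension.

word_embeddings = WordEmbeddings(vocab_size=30522, hidden_size=768, initializer_range=0.02)
input_ids = tf.constant([[101, 7592, 2088, 102]])   # (batch=1, seq_len=4)
hidden_states = word_embeddings(input_ids)          # -> shape (1, 4, 768)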


class TokenTypeEmbeddings(tf.keras.layers.Layer):
    def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
        super().__init__(**kwargs)

        self.type_vocab_size = type_vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.token_type_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.type_vocab_size, self.hidden_size],
            initializer=get_initializer(initializer_range=self.initializer_range),
        )

        super().build(input_shape=input_shape)

    def get_config(self):
        config = {
            "type_vocab_size": self.type_vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, token_type_ids):
        flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
        one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
        embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
        embeddings = tf.reshape(
            tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
        )

        embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])

        return embeddings
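
Illustrative usage sketch, not part of the commit, with assumed sizes. Unlike `WordEmbeddings`, the lookup here is a one-hot matmul: the one-hot tensor is created in the layer's compute dtype before being multiplied against the embedding table.

token_type_embeddings = TokenTypeEmbeddings(type_vocab_size=2, hidden_size=768, initializer_range=0.02)
token_type_ids = tf.constant([[0, 0, 1, 1]])              # (batch=1, seq_len=4)
segment_embeds = token_type_embeddings(token_type_ids)    # -> shape (1, 4, 768)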


class PositionEmbeddings(tf.keras.layers.Layer):
    def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
        super().__init__(**kwargs)

        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.max_position_embeddings, self.hidden_size],
            initializer=get_initializer(initializer_range=self.initializer_range),
        )

        super().build(input_shape)

    def get_config(self):
        config = {
            "max_position_embeddings": self.max_position_embeddings,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        return dict(list(base_config.items()) + list(config.items()))

    def call(self, position_ids):
        input_shape = shape_list(tensor=position_ids)
        position_embeddings = self.position_embeddings[: input_shape[1], :]

        return tf.broadcast_to(input=position_embeddings, shape=input_shape)
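
Illustrative usage sketch, not part of the commit, with assumed sizes. As the slicing and `tf.broadcast_to` in `call` imply, the layer broadcasts a `(seq_len, hidden_size)` slice of its table to the full shape of whatever tensor it receives, so the broadcast only works when it is fed a `(batch, seq_len, hidden_size)` tensor (for example the word embeddings computed just before) rather than integer position ids.

position_embeddings = PositionEmbeddings(max_position_embeddings=512, hidden_size=768, initializer_range=0.02)
inputs_embeds = tf.zeros((1, 4, 768))                 # e.g. the output of WordEmbeddings above
position_embeds = position_embeddings(inputs_embeds)  # -> shape (1, 4, 768)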


class TFSharedEmbeddings(tf.keras.layers.Layer):
    r"""
    Construct shared token embeddings.
