Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New TF embeddings (cleaner and faster) #9418

Merged
merged 24 commits into from
Jan 20, 2021
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 137 additions & 19 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,25 +809,30 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:

return model_embeds

def _get_word_embedding_weight(self, embedding_layer):
if hasattr(embedding_layer, "word_embeddings"):
return embedding_layer.word_embeddings
elif hasattr(embedding_layer, "weight"):
return embedding_layer.weight
elif hasattr(embedding_layer, "decoder"):
return embedding_layer.decoder
else:
# Here we build the word embeddings weights if not exists.
# And then we retry to get the attribute once built.
self(self.dummy_inputs)
if hasattr(embedding_layer, "word_embeddings"):
return embedding_layer.word_embeddings
elif hasattr(embedding_layer, "weight"):
return embedding_layer.weight
elif hasattr(embedding_layer, "decoder"):
return embedding_layer.decoder
else:
return None
def _get_word_embedding_weight(model, embedding_layer):
embeds = getattr(embedding_layer, "weight", None)

if embeds is not None:
return embeds

embeds = getattr(embedding_layer, "decoder", None)

if embeds is not None:
return embeds

model(model.dummy_inputs)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here to say we retry after building the model, just in case it was not already?


embeds = getattr(embedding_layer, "weight", None)

if embeds is not None:
return embeds

embeds = getattr(embedding_layer, "decoder", None)

if embeds is not None:
return embeds

return None

def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings())
Expand Down Expand Up @@ -1319,6 +1324,119 @@ def call(self, x):
return x


class WordEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.word_embeddings = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape=input_shape)

def get_config(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this function needed for?

Copy link
Contributor Author

@jplu jplu Jan 5, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a required function when a layer takes some parameters in its __init__ to become serializable, see more detail in the doc https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer#get_config and https://www.tensorflow.org/guide/keras/custom_layers_and_models#you_can_optionally_enable_serialization_on_your_layers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically this is what does the @keras_serializable

config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, input_ids):
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

those are multiple operation that replaced a single matrix operation no?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

YEs

embeddings = tf.gather(params=self.word_embeddings, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])

return embeddings


class TokenTypeEmbeddings(tf.keras.layers.Layer):
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.type_vocab_size = type_vocab_size
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape=input_shape)

def get_config(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the function required?

config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, token_type_ids):
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])

return embeddings


class PositionEmbeddings(tf.keras.layers.Layer):
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
super().__init__(**kwargs)

self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
)

super().build(input_shape)

def get_config(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same why do we need this function?

config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
"initializer_range": self.initializer_range,
}
base_config = super().get_config()

return dict(list(base_config.items()) + list(config.items()))

def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]

return tf.broadcast_to(input=position_embeddings, shape=input_shape)


class TFSharedEmbeddings(tf.keras.layers.Layer):
r"""
Construct shared token embeddings.
Expand Down
Loading