Full rework of the TF input/output embeddings and bias resizing (#9193)
* Start rework resizing

* Rework bias/decoder resizing

* Full resizing rework

* Full resizing rework

* Start to update the models with the new approach

* Finish to update the models

* Update all the tests

* Update the template

* Fix tests

* Fix tests

* Test a new approach

* Refactoring

* Refactoring

* Refactoring

* New rework

* Rework BART

* Rework bert+blenderbot

* Rework CTRL

* Rework Distilbert

* Rework DPR

* Rework Electra

* Rework Flaubert

* Rework Funnel

* Rework GPT2

* Rework Longformer

* Rework Lxmert

* Rework marian+mbart

* Rework mobilebert

* Rework mpnet

* Rework openai

* Rework pegasus

* Rework Roberta

* Rework T5

* Rework xlm+xlnet

* Rework template

* Fix TFT5EncoderOnly + DPRs

* Restore previous methods

* Fix Funnel

* Fix CTRL and TransforXL

* Apply style

* Apply Sylvain's comments

* Restore a test in DPR

* Address the comments

* Fix bug

* Apply style

* remove unused import

* Fix test

* Forgot a method

* missing test

* Trigger CI

* naming update

* Rebase

* Trigger CI
jplu authored Jan 11, 2021
1 parent cf41676 commit 1243ee7
Showing 40 changed files with 1,475 additions and 595 deletions.
390 changes: 270 additions & 120 deletions src/transformers/modeling_tf_utils.py

Large diffs are not rendered by default.
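
Since that diff is collapsed, here is a minimal sketch — written for this summary, not taken from the PR — of the generic flow those 270 added lines centralize: build a zero-initialized weight of the new vocabulary size, copy the overlapping slice of the old weight, and hand the result back through the new per-model hooks. The helper names below (`_resize_bias`, `resize_all_biases`) are hypothetical.

```python
import tensorflow as tf

def _resize_bias(old_bias: tf.Variable, new_num_tokens: int) -> tf.Variable:
    # Hypothetical helper: keep the overlapping slice, zero-pad the rest.
    num_to_copy = min(int(old_bias.shape[-1]), new_num_tokens)
    paddings = [[0, 0]] * (len(old_bias.shape) - 1) + [[0, new_num_tokens - num_to_copy]]
    new_value = tf.pad(old_bias[..., :num_to_copy], paddings)  # zero-pad on the right
    return tf.Variable(new_value, trainable=old_bias.trainable)

def resize_all_biases(lm_head, new_num_tokens):
    # Each head now answers get_bias() with a name -> weight dict
    # ({"bias": ...} for BERT, {"bias": ..., "decoder_bias": ...} for ALBERT,
    # {"final_logits_bias": ...} for BART), so one loop covers every model
    # instead of one hand-rolled resize_token_embeddings override per model.
    new_biases = {name: _resize_bias(bias, new_num_tokens) for name, bias in lm_head.get_bias().items()}
    lm_head.set_bias(new_biases)
```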

80 changes: 20 additions & 60 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -470,6 +470,21 @@ def build(self, input_shape):

super().build(input_shape)

def get_output_embeddings(self):
return self.decoder

def set_output_embeddings(self, value):
self.decoder.word_embeddings = value
self.decoder.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias, "decoder_bias": self.decoder_bias}

def set_bias(self, value):
self.bias = value["bias"]
self.decoder_bias = value["decoder_bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.activation(hidden_states)
@@ -505,10 +520,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]

def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.embeddings.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -835,34 +847,8 @@ def __init__(self, config, *inputs, **kwargs):
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")
self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier")

def get_output_embeddings(self):
return self.albert.embeddings

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)

init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)

self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -980,34 +966,8 @@ def __init__(self, config, *inputs, **kwargs):
self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert")
self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions")

def get_output_embeddings(self):
return self.albert.embeddings

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# ALBERT is a special case where there are two bias to update
# even though self.bias is not used anywhere and is here
# just to make the loading weights from a PT model happy
if new_num_tokens is not None:
num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens)
self.predictions.vocab_size = num_tokens_to_copy
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/bias"
self.predictions.bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)
self.predictions.bias.assign(init_bias)

init_decoder_bias = tf.zeros((new_num_tokens,))
init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy]
name = self.name + "/" + self.predictions.name + "/decoder_bias"
self.predictions.decoder_bias = self.add_weight(
shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name
)

self.predictions.decoder_bias.assign(init_decoder_bias)
def get_lm_head(self):
return self.predictions

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
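The net effect for ALBERT: the two duplicated hand-rolled resizing overrides above (which also tried to item-assign into a `tf.zeros` tensor, something TF eager tensors do not support) collapse into `get_bias()`/`set_bias()` on the head plus a `get_lm_head()` accessor. A hedged usage sketch, assuming the public `albert-base-v2` checkpoint:

```python
from transformers import TFAlbertForMaskedLM

model = TFAlbertForMaskedLM.from_pretrained("albert-base-v2")
model.resize_token_embeddings(model.config.vocab_size + 8)

# Both biases are now reachable through one dict instead of two ad-hoc attributes.
biases = model.get_lm_head().get_bias()
print(sorted(biases))                # ['bias', 'decoder_bias']
print(int(biases["bias"].shape[0]))  # old vocab size + 8
```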
77 changes: 46 additions & 31 deletions src/transformers/models/bart/modeling_tf_bart.py
@@ -481,6 +481,29 @@ def dummy_inputs(self):
}
return dummy_inputs

def get_input_embeddings(self):
base_model = getattr(self, self.base_model_prefix, self)

return base_model.shared

def set_input_embeddings(self, value):
base_model = getattr(self, self.base_model_prefix, self)

try:
base_model.shared.weight = value
except AttributeError:
self(self.dummy_inputs)
base_model.shared.weight = value

base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0]

with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
pass

embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name)
base_model.encoder.set_embed_tokens(embed_tokens)
base_model.decoder.set_embed_tokens(embed_tokens)

@tf.function(
input_signature=[
{
@@ -634,6 +657,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
else None
)

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

def call(
self,
input_ids=None,
@@ -791,6 +817,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm

def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens

def call(
self,
input_ids=None,
@@ -1009,6 +1038,9 @@ def __init__(self, config: BartConfig, *inputs, **kwargs):
self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")

def get_encoder(self):
return self.encoder

def get_decoder(self):
return self.decoder

@@ -1134,15 +1166,6 @@ def serving_output(self, output):
encoder_attentions=enc_attns,
)

def get_input_embeddings(self):
return self.shared

def set_input_embeddings(self, value):
self.shared = value

def get_output_embeddings(self):
return self.shared


@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
@@ -1166,22 +1189,20 @@ def __init__(self, config, *inputs, **kwargs):
def get_decoder(self):
return self.model.decoder

def resize_token_embeddings(self, new_num_tokens):
super().resize_token_embeddings(new_num_tokens=new_num_tokens)

# BART is a special case where the bias has two dimensions
# and not named just `bias`
if new_num_tokens is not None:
num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens)
init_bias = tf.zeros((new_num_tokens,))
init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy]
self.final_logits_bias = self.add_weight(
shape=(1, new_num_tokens),
initializer="zeros",
trainable=False,
name="final_logits_bias",
)
self.final_logits_bias.assign(init_bias)
def get_encoder(self):
return self.model.encoder

def get_output_embeddings(self):
return self.get_input_embeddings()

def set_output_embeddings(self, value):
self.set_input_embeddings(value)

def get_bias(self):
return {"final_logits_bias": self.final_logits_bias}

def set_bias(self, value):
self.final_logits_bias = value["final_logits_bias"]

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -1356,12 +1377,6 @@ def adjust_logits_during_generation(self, logits, cur_len, max_length):
else:
return logits

def get_output_embeddings(self):
return self.model.shared

def get_encoder(self):
return self.model.encoder

def compute_loss(self, labels, logits):
"""CrossEntropyLoss that ignores pad tokens"""
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
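For BART the moves are: `get_encoder()` joins `get_decoder()` on the base model, input and output embeddings are unified on the shared table (with `set_input_embeddings` rebuilding the `TFWrappedEmbeddings` wrapper and re-injecting it into encoder and decoder via the new `set_embed_tokens`), and the oddly shaped `(1, vocab_size)` `final_logits_bias` flows through the same `get_bias()` dict as every other bias. A sketch of the expected behavior, assuming the public `facebook/bart-base` checkpoint:

```python
from transformers import TFBartForConditionalGeneration

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Input and output embeddings resolve to the same shared table.
assert model.get_output_embeddings() is model.get_input_embeddings()

model.resize_token_embeddings(model.config.vocab_size + 4)
print(model.get_bias()["final_logits_bias"].shape)  # expected: (1, old vocab size + 4)
```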
35 changes: 22 additions & 13 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -15,6 +15,7 @@
# limitations under the License.
""" TF 2.0 BERT model. """

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

@@ -526,6 +527,20 @@ def build(self, input_shape):

super().build(input_shape)

def get_output_embeddings(self):
return self.input_embeddings

def set_output_embeddings(self, value):
self.input_embeddings.word_embeddings = value
self.input_embeddings.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias}

def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.input_embeddings(hidden_states, mode="linear")
@@ -582,7 +597,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.embeddings.word_embeddings = value
self.embeddings.vocab_size = value.shape[0]
self.embeddings.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -918,13 +933,11 @@ def __init__(self, config, *inputs, **kwargs):
self.nsp = TFBertNSPHead(config, name="nsp___cls")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1044,13 +1057,11 @@ def __init__(self, config, *inputs, **kwargs):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -1149,13 +1160,11 @@ def __init__(self, config, *inputs, **kwargs):
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

def get_output_embeddings(self):
return self.bert.embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.mlm.predictions

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name

@add_code_sample_docstrings(
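BERT's three LM-head-bearing classes swap `get_output_embeddings`/`get_output_layer_with_bias` for a single `get_lm_head()`, and `get_prefix_bias_name()` now emits a `FutureWarning` pointing at `get_bias()`. A sketch of the deprecation path, assuming the public `bert-base-uncased` checkpoint:

```python
import warnings

from transformers import TFBertForMaskedLM

model = TFBertForMaskedLM.from_pretrained("bert-base-uncased")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.get_prefix_bias_name()  # deprecated; returns a prefix string like "tf_bert_for_masked_lm/mlm___cls/predictions"
    assert any(issubclass(w.category, FutureWarning) for w in caught)

bias = model.get_lm_head().get_bias()["bias"]  # the supported replacement
```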
27 changes: 19 additions & 8 deletions src/transformers/models/ctrl/modeling_tf_ctrl.py
@@ -15,6 +15,8 @@
# limitations under the License.
""" TF 2.0 CTRL model."""

import warnings

import numpy as np
import tensorflow as tf

@@ -242,10 +244,7 @@ def get_input_embeddings(self):

def set_input_embeddings(self, value):
self.w.weight = value
self.w.vocab_size = value.shape[0]

def _resize_token_embeddings(self, new_num_tokens):
raise NotImplementedError
self.w.vocab_size = shape_list(value)[0]

def _prune_heads(self, heads_to_prune):
"""
@@ -618,6 +617,20 @@ def build(self, input_shape):
self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
super().build(input_shape)

def get_output_embeddings(self):
return self.input_embeddings

def set_output_embeddings(self, value):
self.input_embeddings.weight = value
self.input_embeddings.vocab_size = shape_list(value)[0]

def get_bias(self):
return {"bias": self.bias}

def set_bias(self, value):
self.bias = value["bias"]
self.vocab_size = shape_list(value["bias"])[0]

def call(self, hidden_states):
hidden_states = self.input_embeddings(hidden_states, mode="linear")
hidden_states = hidden_states + self.bias
@@ -638,13 +651,11 @@ def __init__(self, config, *inputs, **kwargs):

self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head")

def get_output_embeddings(self):
return self.lm_head.input_embeddings

def get_output_layer_with_bias(self):
def get_lm_head(self):
return self.lm_head

def get_prefix_bias_name(self):
warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning)
return self.name + "/" + self.lm_head.name

def prepare_inputs_for_generation(self, inputs, past, **kwargs):
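The CTRL changes follow the same recipe as the heads above, so they double as the template for any new TF head: keep the bias on the head and expose the four hooks. A condensed illustration of that contract — a hypothetical class, assuming a `TFSharedEmbeddings`-style embedding layer whose call supports `mode="linear"`:

```python
import tensorflow as tf

class TFCustomLMHead(tf.keras.layers.Layer):
    # Hypothetical head showing the contract the rework standardizes on.

    def __init__(self, vocab_size, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.input_embeddings = input_embeddings  # weight-tied embedding layer

    def build(self, input_shape):
        # The bias lives on the head so resizing can swap it out independently
        # of the tied embedding weights.
        self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
        super().build(input_shape)

    def get_output_embeddings(self):
        return self.input_embeddings

    def set_output_embeddings(self, value):
        self.input_embeddings.weight = value
        self.input_embeddings.vocab_size = value.shape[0]

    def get_bias(self):
        return {"bias": self.bias}

    def set_bias(self, value):
        self.bias = value["bias"]
        self.vocab_size = value["bias"].shape[0]

    def call(self, hidden_states):
        # Project back to vocab space with the tied weights, then add the bias.
        hidden_states = self.input_embeddings(hidden_states, mode="linear")
        return hidden_states + self.bias
```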