diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 694984ac02e6..e9a21f89d112 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -530,6 +530,46 @@ def load_tf_weights(model, resolved_archive_file):
     return missing_layers, unexpected_layers
 
 
+def init_copy_embeddings(old_embeddings, new_num_tokens):
+    r"""
+    This function aims to reduce the embeddings in case new_num_tokens < old_num_tokens or to pad with -1 in case
+    new_num_tokens > old_num_tokens. A mask is also computed in order to know which weight in the embeddings should be
+    kept or not. Example:
+
+        - if new_num_tokens=5 and old_num_tokens=4 and old_embeddings=[w1,w2,w3,w4]
+
+            - mask=[True,True,True,True,False] and current_weights=[w1,w2,w3,w4,-1]
+        - if new_num_tokens=4 and old_num_tokens=5 and old_embeddings=[w1,w2,w3,w4,w5]
+
+            - mask=[True,True,True,True] and current_weights=[w1,w2,w3,w4]
+    """
+    old_num_tokens, old_embedding_dim = shape_list(old_embeddings)
+    size_diff = new_num_tokens - old_num_tokens
+
+    # initialize new embeddings
+    # Copy token embeddings from the previous ones
+    if tf.math.greater(size_diff, 0):
+        # if the new size is greater than the old one, we extend the current embeddings with a padding up to the new size
+        # and we create a mask to properly identify the padded values that will be replaced by the values of the newly
+        # created embeddings
+        current_weights = tf.pad(
+            old_embeddings.value(), tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=-1
+        )
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+        mask = tf.fill(tf.convert_to_tensor([num_tokens_to_copy, 1]), True)
+        mask = tf.pad(mask, tf.convert_to_tensor([[0, size_diff], [0, 0]]), constant_values=False)
+    else:
+        # if the new size is lower than the old one, we take the current embeddings until the new size
+        current_weights = tf.slice(
+            old_embeddings.value(),
+            tf.convert_to_tensor([0, 0]),
+            tf.convert_to_tensor([new_num_tokens, old_embedding_dim]),
+        )
+        mask = tf.fill(tf.convert_to_tensor([new_num_tokens, 1]), True)
+
+    return mask, current_weights
+
+
 class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     r"""
     Base class for all TF models.
@@ -615,58 +655,132 @@ def serving_output(output):
 
     def get_input_embeddings(self) -> tf.keras.layers.Layer:
         """
-        Returns the model's input embeddings.
+        Returns the model's input embeddings layer.
 
         Returns:
-            :obj:`tf.keras.layers.Layer`: A torch module mapping vocabulary to hidden states.
+            :obj:`tf.Variable`: The embeddings layer mapping vocabulary to hidden states.
         """
-        base_model = getattr(self, self.base_model_prefix, self)
+        main_layer = getattr(self, self.base_model_prefix, self)
 
-        if base_model is not self:
-            return base_model.get_input_embeddings()
+        if main_layer is not self:
+            return main_layer.get_input_embeddings()
         else:
             raise NotImplementedError
 
     def set_input_embeddings(self, value):
         """
-        Set model's input embeddings.
+        Set model's input embeddings
 
         Args:
-            value (:obj:`tf.keras.layers.Layer`):
-                A module mapping vocabulary to hidden states.
+            value (:obj:`tf.Variable`):
+                The new weights mapping vocabulary to hidden states.
""" - base_model = getattr(self, self.base_model_prefix, self) - if base_model is not self: - base_model.set_input_embeddings(value) - else: - raise NotImplementedError + main_layer = getattr(self, self.base_model_prefix) + + if main_layer is None: + raise NotImplementedError("The model does not implements the base_model_prefix attribute.") - def get_output_embeddings(self) -> tf.keras.layers.Layer: + try: + main_layer.set_input_embeddings(value) + except AttributeError: + logger.info("Building the model") + self(self.dummy_inputs) + main_layer.set_input_embeddings(value) + + def get_output_embeddings(self) -> Union[None, tf.keras.layers.Layer]: """ Returns the model's output embeddings Returns: - :obj:`tf.keras.layers.Layer`: A torch module mapping hidden states to vocabulary. + :obj:`tf.Variable`: The new weights mapping vocabulary to hidden states. """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + + return lm_head.get_output_embeddings() + return None # Overwrite for models with output embeddings + def set_output_embeddings(self, value): + """ + Set model's output embeddings + + Args: + value (:obj:`tf.Variable`): + The new weights mapping hidden states to vocabulary. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_output_embeddings(value) + except AttributeError: + logger.info("Building the model") + self(self.dummy_inputs) + lm_head.set_output_embeddings(value) + def get_output_layer_with_bias(self) -> Union[None, tf.keras.layers.Layer]: """ Get the layer that handles a bias attribute in case the model has an LM head with weights tied to the - embeddings. + embeddings Return: :obj:`tf.keras.layers.Layer`: The layer that handles the bias, None if not an LM model. """ - return None + warnings.warn( + "The method get_output_layer_with_bias is deprecated. Please use `get_lm_head` instead.", FutureWarning + ) + return self.get_lm_head() def get_prefix_bias_name(self) -> Union[None, str]: """ - Get the concatenated prefix name of the bias from the model name to the parent layer. + Get the concatenated prefix name of the bias from the model name to the parent layer Return: :obj:`str`: The prefix name of the bias. """ + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return None + + def get_bias(self) -> Union[None, Dict[str, tf.Variable]]: + """ + Dict of bias attached to an LM head. The key represents the name of the bias attribute. + + Return: + :obj:`tf.Variable`: The weights representing the bias, None if not an LM model. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + return lm_head.get_bias() + except AttributeError: + self(self.dummy_inputs) + + return lm_head.get_bias() + return None + + def set_bias(self, value): + """ + Set all the bias in the LM head. + + Args: + value (:obj:`Dict[tf.Variable]`): + All the new bias attached to an LM head. + """ + if self.get_lm_head() is not None: + lm_head = self.get_lm_head() + try: + lm_head.set_bias(value) + except AttributeError: + self(self.dummy_inputs) + lm_head.set_bias(value) + + def get_lm_head(self) -> tf.keras.layers.Layer: + """ + The LM Head layer. This method must be overwritten by all the models that have a lm head. + + Return: + :obj:`tf.keras.layers.Layer`: The LM head layer if the model has one, None if not. 
+ """ return None def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: @@ -685,143 +799,179 @@ def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable: Return: :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model. """ - model_embeds = self._resize_token_embeddings(new_num_tokens) - if new_num_tokens is None: - return model_embeds + if new_num_tokens is None or new_num_tokens == self.config.vocab_size: + return self._get_word_embedding_weight(self.get_input_embeddings()) - return model_embeds + model_embeds = self._resize_token_embeddings(new_num_tokens) - def _resize_token_embeddings(self, new_num_tokens): - # get_input_embeddings and set_input_embeddings need to be implemented in base layer. - base_model = getattr(self, self.base_model_prefix, self) - old_embeddings = base_model.get_input_embeddings() - new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) - base_model.set_input_embeddings(new_embeddings) # Update base model and current model config self.config.vocab_size = new_num_tokens - base_model.vocab_size = new_num_tokens - return base_model.get_input_embeddings() - - def _get_word_embeddings(self, embeddings): - if hasattr(embeddings, "word_embeddings"): - # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings - return embeddings.word_embeddings - elif hasattr(embeddings, "weight"): - # TFSharedEmbeddings - return embeddings.weight + + return model_embeds + + def _get_word_embedding_weight(self, embedding_layer): + if hasattr(embedding_layer, "word_embeddings"): + return embedding_layer.word_embeddings + elif hasattr(embedding_layer, "weight"): + return embedding_layer.weight + elif hasattr(embedding_layer, "decoder"): + return embedding_layer.decoder else: # Here we build the word embeddings weights if not exists. # And then we retry to get the attribute once built. 
- embeddings.build([]) - if hasattr(embeddings, "word_embeddings"): - # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings - return embeddings.word_embeddings - elif hasattr(embeddings, "weight"): - # TFSharedEmbeddings - return embeddings.weight + self(self.dummy_inputs) + if hasattr(embedding_layer, "word_embeddings"): + return embedding_layer.word_embeddings + elif hasattr(embedding_layer, "weight"): + return embedding_layer.weight + elif hasattr(embedding_layer, "decoder"): + return embedding_layer.decoder else: - raise ValueError("word embedding is not defined.") + return None - def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable: + def _resize_token_embeddings(self, new_num_tokens): + old_embeddings = self._get_word_embedding_weight(self.get_input_embeddings()) + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens) + + # if word embeddings are not tied, make sure that lm head bias is resized as well + if self.get_bias() is not None: + old_lm_head_bias = self.get_bias() + new_lm_head_bias = self._get_resized_lm_head_bias(old_lm_head_bias, new_num_tokens) + + self.set_bias(new_lm_head_bias) + + # if word embeddings are not tied, make sure that lm head decoder is resized as well + if self.get_output_embeddings() is not None: + old_lm_head_decoder = self._get_word_embedding_weight(self.get_output_embeddings()) + new_lm_head_decoder = self._get_resized_lm_head_decoder(old_lm_head_decoder, new_num_tokens) + + self.set_output_embeddings(new_lm_head_decoder) + + self.set_input_embeddings(new_embeddings) + + return self.get_input_embeddings() + + def _get_resized_lm_head_bias(self, old_lm_head_bias, new_num_tokens): """ - Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly - initialized vectors at the end. Reducing the size will remove vectors from the end + Build a resized bias from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end Args: - old_embeddings (:obj:`tf.Variable`): - Old embeddings to be resized. + old_lm_head_bias (:obj:`tf.Variable`): + Old lm head bias to be resized. new_num_tokens (:obj:`int`, `optional`): - New number of tokens in the embedding matrix. + New number of tokens in the linear matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove - vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens - :obj:`tf.Variable`` module of the model without doing anything. + vectors from the end. If not provided or :obj:`None`, just returns None Return: - :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if - :obj:`new_num_tokens` is :obj:`None` + :obj:`tf.Variable`: Pointer to the resized bias. 
""" - word_embeddings = self._get_word_embeddings(old_embeddings) - bias_layer = self.get_output_layer_with_bias() + new_lm_head_bias = {} + + for attr, weight in old_lm_head_bias.items(): + first_dim, old_num_tokens = (None, shape_list(weight)[0]) if tf.rank(weight) == 1 else shape_list(weight) + size_diff = new_num_tokens - old_num_tokens + final_shape = [new_num_tokens] if first_dim is None else [first_dim, new_num_tokens] + + # initialize new bias + if tf.math.greater(size_diff, 0): + padding_shape = [[0, size_diff]] if first_dim is None else [[0, 0], [0, size_diff]] + current_bias = tf.pad(weight.value(), tf.convert_to_tensor(padding_shape), constant_values=-1) + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + mask_shape = [num_tokens_to_copy] if first_dim is None else [1, num_tokens_to_copy] + bias_mask = tf.fill(tf.convert_to_tensor(mask_shape), True) + bias_mask = tf.pad(bias_mask, tf.convert_to_tensor(padding_shape), constant_values=False) + else: + slice_from = [0] if first_dim is None else [0, 0] + current_bias = tf.slice( + weight.value(), tf.convert_to_tensor(slice_from), tf.convert_to_tensor(final_shape) + ) + bias_mask = tf.fill(tf.convert_to_tensor(final_shape), True) - if new_num_tokens is None: - return word_embeddings + new_bias = self.add_weight( + shape=final_shape, + initializer="zeros", + trainable=True, + name=weight.name.split(":")[0], + ) + init_bias = tf.where(bias_mask, current_bias, new_bias.value()) - old_num_tokens, old_embedding_dim = word_embeddings.shape + new_bias.assign(init_bias) + new_lm_head_bias[attr] = new_bias - if old_num_tokens == new_num_tokens: - return word_embeddings + return new_lm_head_bias - # initialize new embeddings - # todo: initializer range is not always passed in config. - init_range = getattr(self.config, "initializer_range", 0.02) - name = ( - self.name - + "/" - + self.base_model_prefix - + "/" - + old_embeddings.name - + "/" - + word_embeddings.name.split(":")[0] - ) - new_embeddings = self.add_weight( - name=name, - shape=[new_num_tokens, old_embedding_dim], - initializer=get_initializer(init_range), - dtype=tf.float32, - ) - init_weights = tf.make_ndarray(tf.make_tensor_proto(new_embeddings.value())) - - # Copy token embeddings from the previous weights - num_tokens_to_copy = min(old_num_tokens, new_num_tokens) - init_weights[:num_tokens_to_copy] = word_embeddings.value()[:num_tokens_to_copy, :] - new_embeddings.assign(init_weights) + def _get_resized_lm_head_decoder(self, old_lm_head_decoder, new_num_tokens): + """ + Build a resized decoder from the old ones. Increasing the size will add newly initialized vectors at the end. + Reducing the size will remove vectors from the end - if bias_layer is not None: - if not hasattr(bias_layer, "bias"): - bias_layer.build([]) + Args: + old_lm_head_decoder (:obj:`tf.Variable`): + Old lm head decoder to be resized. + new_num_tokens (:obj:`int`, `optional`): + New number of tokens in the linear matrix. - # Second check in order to be sure the attribute has been properly created - if not hasattr(bias_layer, "bias"): - raise ValueError("bias is not defined.") + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. 
If not provided or :obj:`None`, just returns None
-            # initialize bias
-            init_bias = np.zeros((new_num_tokens,))
-            init_bias[:num_tokens_to_copy] = bias_layer.bias.value()[
-                :num_tokens_to_copy
-            ]  # tf.make_ndarray(tf.make_tensor_proto(bias_layer.bias.value()))[:num_tokens_to_copy]
+
+        Return:
+            :obj:`tf.Variable`: Pointer to the resized decoder or None if the output embeddings are different from the
+            input ones.
+        """
+        new_lm_head_decoder = old_lm_head_decoder
+        is_input_output_equals = tf.reduce_any(
+            self._get_word_embedding_weight(self.get_input_embeddings()) == old_lm_head_decoder
+        )
 
-            bias_layer.bias = self.add_weight(
-                shape=(new_num_tokens,),
+        if old_lm_head_decoder is not None and not is_input_output_equals:
+            old_embedding_dim = shape_list(old_lm_head_decoder)[1]
+            decoder_mask, current_decoder = init_copy_embeddings(old_lm_head_decoder, new_num_tokens)
+            new_lm_head_decoder = self.add_weight(
+                shape=(new_num_tokens, old_embedding_dim),
                 initializer="zeros",
                 trainable=True,
-                name=self.get_prefix_bias_name() + "/bias",
+                name=old_lm_head_decoder.name.split(":")[0],
             )
+            init_decoder = tf.where(decoder_mask, current_decoder, new_lm_head_decoder.value())
 
-            bias_layer.bias.assign(init_bias)
+            new_lm_head_decoder.assign(init_decoder)
 
-        output_embeddings = self.get_output_embeddings()
+        return new_lm_head_decoder
 
-        if output_embeddings is not None:
-            if self.get_input_embeddings() != output_embeddings:
-                if not hasattr(output_embeddings, "decoder"):
-                    output_embeddings.build([])
+    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
+        """
+        Build resized embedding weights from provided token embedding weights. Increasing the size will add newly
+        initialized vectors at the end. Reducing the size will remove vectors from the end
 
-                # Second check in order to be sure the attribute has been properly created
-                if not hasattr(output_embeddings, "decoder"):
-                    raise ValueError("decoder is not defined.")
+        Args:
+            old_embeddings (:obj:`tf.Variable`):
+                Old embeddings to be resized.
+            new_num_tokens (:obj:`int`, `optional`):
+                New number of tokens in the embedding matrix.
 
-                # initialize decoder
-                init_weights = np.zeros((new_num_tokens, old_embedding_dim))
-                init_weights[:num_tokens_to_copy] = output_embeddings.decoder.value()[:num_tokens_to_copy, :]
+                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
+                :obj:`tf.Variable` module of the model without doing anything.
- output_embeddings.decoder = self.add_weight( - shape=(new_num_tokens, old_embedding_dim), - initializer="zeros", - trainable=True, - name=self.get_prefix_bias_name() + "/decoder/weight", - ) - output_embeddings.decoder.assign(init_weights) + Return: + :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if + :obj:`new_num_tokens` is :obj:`None` + """ + old_embedding_dim = shape_list(old_embeddings)[1] + init_range = getattr(self.config, "initializer_range", 0.02) + embeddings_mask, current_embeddings = init_copy_embeddings(old_embeddings, new_num_tokens) + new_embeddings = self.add_weight( + name=old_embeddings.name.split(":")[0], + shape=[new_num_tokens, old_embedding_dim], + initializer=get_initializer(init_range), + dtype=tf.float32, + ) + init_embeddings = tf.where(embeddings_mask, current_embeddings, new_embeddings.value()) + + new_embeddings.assign(init_embeddings) return new_embeddings diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py index fd0c752f3725..1bd953f7db45 100644 --- a/src/transformers/models/albert/modeling_tf_albert.py +++ b/src/transformers/models/albert/modeling_tf_albert.py @@ -470,6 +470,21 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.word_embeddings = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias, "decoder_bias": self.decoder_bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.decoder_bias = value["decoder_bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.activation(hidden_states) @@ -505,10 +520,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -834,34 +846,8 @@ def __init__(self, config, *inputs, **kwargs): self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions") self.sop_classifier = TFAlbertSOPHead(config, name="sop_classifier") - def get_output_embeddings(self): - return self.albert.embeddings - - def resize_token_embeddings(self, new_num_tokens): - super().resize_token_embeddings(new_num_tokens=new_num_tokens) - - # ALBERT is a special case where there are two bias to update - # even though self.bias is not used anywhere and is here - # just to make the loading weights from a PT model happy - if new_num_tokens is not None: - num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens) - self.predictions.vocab_size = num_tokens_to_copy - init_bias = tf.zeros((new_num_tokens,)) - init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy] - name = self.name + "/" + self.predictions.name + "/bias" - self.predictions.bias = self.add_weight( - shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name - ) - self.predictions.bias.assign(init_bias) - - init_decoder_bias = tf.zeros((new_num_tokens,)) - init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy] - name = self.name + "/" + self.predictions.name + "/decoder_bias" - 
self.predictions.decoder_bias = self.add_weight( - shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name - ) - - self.predictions.decoder_bias.assign(init_decoder_bias) + def get_lm_head(self): + return self.predictions @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @replace_return_docstrings(output_type=TFAlbertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) @@ -979,34 +965,8 @@ def __init__(self, config, *inputs, **kwargs): self.albert = TFAlbertMainLayer(config, add_pooling_layer=False, name="albert") self.predictions = TFAlbertMLMHead(config, self.albert.embeddings, name="predictions") - def get_output_embeddings(self): - return self.albert.embeddings - - def resize_token_embeddings(self, new_num_tokens): - super().resize_token_embeddings(new_num_tokens=new_num_tokens) - - # ALBERT is a special case where there are two bias to update - # even though self.bias is not used anywhere and is here - # just to make the loading weights from a PT model happy - if new_num_tokens is not None: - num_tokens_to_copy = min(self.predictions.bias.shape[0], new_num_tokens) - self.predictions.vocab_size = num_tokens_to_copy - init_bias = tf.zeros((new_num_tokens,)) - init_bias[:num_tokens_to_copy] = self.predictions.bias.value()[:num_tokens_to_copy] - name = self.name + "/" + self.predictions.name + "/bias" - self.predictions.bias = self.add_weight( - shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name - ) - self.predictions.bias.assign(init_bias) - - init_decoder_bias = tf.zeros((new_num_tokens,)) - init_decoder_bias[:num_tokens_to_copy] = self.predictions.decoder_bias.value()[:num_tokens_to_copy] - name = self.name + "/" + self.predictions.name + "/decoder_bias" - self.predictions.decoder_bias = self.add_weight( - shape=(new_num_tokens,), initializer="zeros", trainable=True, name=name - ) - - self.predictions.decoder_bias.assign(init_decoder_bias) + def get_lm_head(self): + return self.predictions @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index db26af101f87..10410d0a876b 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -481,6 +481,29 @@ def dummy_inputs(self): } return dummy_inputs + def get_input_embeddings(self): + base_model = getattr(self, self.base_model_prefix, self) + + return base_model.shared + + def set_input_embeddings(self, value): + base_model = getattr(self, self.base_model_prefix, self) + + try: + base_model.shared.weight = value + except AttributeError: + self(self.dummy_inputs) + base_model.shared.weight = value + + base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0] + + with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: + pass + + embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name) + base_model.encoder.set_embed_tokens(embed_tokens) + base_model.decoder.set_embed_tokens(embed_tokens) + @tf.function( input_signature=[ { @@ -634,6 +657,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings else None ) + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -791,6 +817,9 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[TFSharedEmbeddings 
self.dropout = tf.keras.layers.Dropout(config.dropout) self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -1009,6 +1038,9 @@ def __init__(self, config: BartConfig, *inputs, **kwargs): self.encoder = TFBartEncoder(config, embed_tokens, name="encoder") self.decoder = TFBartDecoder(config, embed_tokens, name="decoder") + def get_encoder(self): + return self.encoder + def get_decoder(self): return self.decoder @@ -1134,15 +1166,6 @@ def serving_output(self, output): encoder_attentions=enc_attns, ) - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - - def get_output_embeddings(self): - return self.shared - @add_start_docstrings( "The BART Model with a language modeling head. Can be used for summarization.", @@ -1166,22 +1189,20 @@ def __init__(self, config, *inputs, **kwargs): def get_decoder(self): return self.model.decoder - def resize_token_embeddings(self, new_num_tokens): - super().resize_token_embeddings(new_num_tokens=new_num_tokens) - - # BART is a special case where the bias has two dimensions - # and not named just `bias` - if new_num_tokens is not None: - num_tokens_to_copy = min(self.final_logits_bias.shape[0], new_num_tokens) - init_bias = tf.zeros((new_num_tokens,)) - init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy] - self.final_logits_bias = self.add_weight( - shape=(1, new_num_tokens), - initializer="zeros", - trainable=False, - name="final_logits_bias", - ) - self.final_logits_bias.assign(init_bias) + def get_encoder(self): + return self.model.encoder + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) + + def get_bias(self): + return {"final_logits_bias": self.final_logits_bias} + + def set_bias(self, value): + self.final_logits_bias = value["final_logits_bias"] @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @@ -1356,12 +1377,6 @@ def adjust_logits_during_generation(self, logits, cur_len, max_length): else: return logits - def get_output_embeddings(self): - return self.model.shared - - def get_encoder(self): - return self.model.encoder - def compute_loss(self, labels, logits): """CrossEntropyLoss that ignores pad tokens""" loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py index 2549868f64f2..c99aad3d29e8 100644 --- a/src/transformers/models/bert/modeling_tf_bert.py +++ b/src/transformers/models/bert/modeling_tf_bert.py @@ -15,6 +15,7 @@ # limitations under the License. """ TF 2.0 BERT model. 
""" +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -526,6 +527,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.input_embeddings(hidden_states, mode="linear") @@ -582,7 +597,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -918,13 +933,11 @@ def __init__(self, config, *inputs, **kwargs): self.nsp = TFBertNSPHead(config, name="nsp___cls") self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls") - def get_output_embeddings(self): - return self.bert.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @@ -1044,13 +1057,11 @@ def __init__(self, config, *inputs, **kwargs): self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls") - def get_output_embeddings(self): - return self.bert.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @@ -1153,13 +1164,11 @@ def __init__(self, config, *inputs, **kwargs): self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert") self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls") - def get_output_embeddings(self): - return self.bert.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name @add_code_sample_docstrings( diff --git a/src/transformers/models/ctrl/modeling_tf_ctrl.py b/src/transformers/models/ctrl/modeling_tf_ctrl.py index 452cdaff1714..b7b12910a9cb 100644 --- a/src/transformers/models/ctrl/modeling_tf_ctrl.py +++ b/src/transformers/models/ctrl/modeling_tf_ctrl.py @@ -15,6 +15,8 @@ # limitations under the License. 
""" TF 2.0 CTRL model.""" +import warnings + import numpy as np import tensorflow as tf @@ -242,10 +244,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.w.weight = value - self.w.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.w.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -620,6 +619,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -640,13 +653,11 @@ def __init__(self, config, *inputs, **kwargs): self.lm_head = TFCTRLLMHead(config, self.transformer.w, name="lm_head") - def get_output_embeddings(self): - return self.lm_head.input_embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name def prepare_inputs_for_generation(self, inputs, past, **kwargs): diff --git a/src/transformers/models/distilbert/modeling_tf_distilbert.py b/src/transformers/models/distilbert/modeling_tf_distilbert.py index 175127976e79..c038dfb74fe8 100644 --- a/src/transformers/models/distilbert/modeling_tf_distilbert.py +++ b/src/transformers/models/distilbert/modeling_tf_distilbert.py @@ -16,6 +16,8 @@ TF 2.0 DistilBERT model """ +import warnings + import tensorflow as tf from ...activations_tf import get_tf_activation @@ -39,7 +41,6 @@ TFPreTrainedModel, TFQuestionAnsweringLoss, TFSequenceClassificationLoss, - TFSharedEmbeddings, TFTokenClassificationLoss, get_initializer, input_processing, @@ -72,9 +73,6 @@ def __init__(self, config, **kwargs): self.vocab_size = config.vocab_size self.dim = config.dim self.initializer_range = config.initializer_range - self.word_embeddings = TFSharedEmbeddings( - config.vocab_size, config.dim, initializer_range=config.initializer_range, name="word_embeddings" - ) # padding_idx=0) self.position_embeddings = tf.keras.layers.Embedding( config.max_position_embeddings, config.dim, @@ -652,6 +650,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -675,13 +687,11 @@ def __init__(self, config, *inputs, **kwargs): self.vocab_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, 
name="vocab_layer_norm") self.vocab_projector = TFDistilBertLMHead(config, self.distilbert.embeddings, name="vocab_projector") - def get_output_embeddings(self): - return self.vocab_projector.input_embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.vocab_projector def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.vocab_projector.name @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/dpr/modeling_tf_dpr.py b/src/transformers/models/dpr/modeling_tf_dpr.py index 79bd4384d1d1..47e61d71bea8 100644 --- a/src/transformers/models/dpr/modeling_tf_dpr.py +++ b/src/transformers/models/dpr/modeling_tf_dpr.py @@ -577,7 +577,11 @@ def __init__(self, config: DPRConfig, *args, **kwargs): self.ctx_encoder = TFDPREncoderLayer(config, name="ctx_encoder") def get_input_embeddings(self): - return self.ctx_encoder.bert_model.get_input_embeddings() + try: + return self.ctx_encoder.bert_model.get_input_embeddings() + except AttributeError: + self(self.dummy_inputs) + return self.ctx_encoder.bert_model.get_input_embeddings() @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFDPRContextEncoderOutput, config_class=_CONFIG_FOR_DOC) @@ -675,7 +679,11 @@ def __init__(self, config: DPRConfig, *args, **kwargs): self.question_encoder = TFDPREncoderLayer(config, name="question_encoder") def get_input_embeddings(self): - return self.question_encoder.bert_model.get_input_embeddings() + try: + return self.question_encoder.bert_model.get_input_embeddings() + except AttributeError: + self(self.dummy_inputs) + return self.question_encoder.bert_model.get_input_embeddings() @add_start_docstrings_to_model_forward(TF_DPR_ENCODERS_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFDPRQuestionEncoderOutput, config_class=_CONFIG_FOR_DOC) @@ -772,7 +780,11 @@ def __init__(self, config: DPRConfig, *args, **kwargs): self.span_predictor = TFDPRSpanPredictorLayer(config, name="span_predictor") def get_input_embeddings(self): - return self.span_predictor.encoder.bert_model.get_input_embeddings() + try: + return self.span_predictor.encoder.bert_model.get_input_embeddings() + except AttributeError: + self(self.dummy_inputs) + return self.span_predictor.encoder.bert_model.get_input_embeddings() @add_start_docstrings_to_model_forward(TF_DPR_READER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFDPRReaderOutput, config_class=_CONFIG_FOR_DOC) diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py index c494ab062bb3..2f29b7a2fac7 100644 --- a/src/transformers/models/electra/modeling_tf_electra.py +++ b/src/transformers/models/electra/modeling_tf_electra.py @@ -14,6 +14,7 @@ # limitations under the License. """ TF Electra model. 
""" +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -511,10 +512,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -919,6 +917,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states, training=False): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -950,13 +962,11 @@ def __init__(self, config, **kwargs): self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head") - def get_output_embeddings(self): - return self.electra.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.generator_lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.generator_lm_head.name @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/flaubert/modeling_tf_flaubert.py b/src/transformers/models/flaubert/modeling_tf_flaubert.py index 09e42b383053..be9623381aa5 100644 --- a/src/transformers/models/flaubert/modeling_tf_flaubert.py +++ b/src/transformers/models/flaubert/modeling_tf_flaubert.py @@ -18,6 +18,7 @@ import itertools import random +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -481,6 +482,10 @@ def __init__(self, config, **kwargs): def get_input_embeddings(self): return self.embeddings + def set_input_embeddings(self, value): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + def call( self, input_ids=None, @@ -724,6 +729,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -770,13 +789,11 @@ def __init__(self, config, *inputs, **kwargs): self.transformer = TFFlaubertMainLayer(config, name="transformer") self.pred_layer = TFFlaubertPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj") - def get_output_embeddings(self): - return self.pred_layer.input_embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.pred_layer def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.pred_layer.name def prepare_inputs_for_generation(self, inputs, **kwargs): diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index bf3540b7be1b..4e0cae89d607 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -14,6 +14,7 @@ # limitations under the License. """ TF 2.0 Funnel model. """ +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -787,7 +788,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models @@ -873,7 +874,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models @@ -992,6 +993,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states, training=False): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -1359,13 +1374,11 @@ def __init__(self, config, *inputs, **kwargs): self.funnel = TFFunnelMainLayer(config, name="funnel") self.lm_head = TFFunnelMaskedLMHead(config, self.funnel.embeddings, name="lm_head") - def get_output_embeddings(self): - return self.funnel.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index cf68bc09131e..da22252e87e2 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -243,7 +243,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.wte.weight = value - self.wte.vocab_size = self.wte.weight.shape[0] + self.wte.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -656,7 +656,10 @@ def __init__(self, config, *inputs, **kwargs): self.transformer = TFGPT2MainLayer(config, name="transformer") def get_output_embeddings(self): - return self.transformer.wte + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) def prepare_inputs_for_generation(self, inputs, past, **kwargs): # only last token for inputs_ids if past is defined in kwargs @@ -779,9 +782,6 @@ def __init__(self, config, *inputs, **kwargs): config, initializer_range=config.initializer_range, name="multiple_choice_head" ) - def get_output_embeddings(self): - return self.transformer.wte - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFGPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -953,9 +953,6 @@ def __init__(self, config, *inputs, **kwargs): ) self.transformer = TFGPT2MainLayer(config, name="transformer") - def get_output_embeddings(self): - return self.transformer.wte - @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index 69f315c96eaf..373726674f4d 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -1182,6 +1182,44 @@ def dummy_inputs(self): } return dummy_inputs + def get_input_embeddings(self): + base_model = getattr(self, self.base_model_prefix, self) + + return base_model.shared + + def set_input_embeddings(self, value): + base_model = getattr(self, self.base_model_prefix, self) + + try: + base_model.shared.weight = value + except AttributeError: + self(self.dummy_inputs) + base_model.shared.weight = value + + base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0] + + with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: + pass + + embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name) + base_model.encoder.set_embed_tokens(embed_tokens) + base_model.decoder.set_embed_tokens(embed_tokens) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"), + "decoder_attention_mask": tf.TensorSpec((None, None), tf.int32, name="decoder_attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + @dataclass # Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder @@ -1483,6 +1521,9 @@ 
def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings] self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -1714,6 +1755,9 @@ def __init__(self, config: LEDConfig, embed_tokens: Optional[TFSharedEmbeddings] self.dropout = tf.keras.layers.Dropout(config.dropout) + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -1921,6 +1965,9 @@ def __init__(self, config: LEDConfig, *inputs, **kwargs): self.encoder = TFLEDEncoder(config, embed_tokens, name="encoder") self.decoder = TFLEDDecoder(config, embed_tokens, name="decoder") + def get_encoder(self): + return self.encoder + def get_decoder(self): return self.decoder @@ -2047,15 +2094,6 @@ def serving_output(self, output): encoder_global_attentions=enc_g_attns, ) - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - - def get_output_embeddings(self): - return self.shared - @add_start_docstrings( "The LED Model with a language modeling head. Can be used for summarization.", @@ -2079,22 +2117,20 @@ def __init__(self, config, *inputs, **kwargs): def get_decoder(self): return self.led.decoder - def resize_token_embeddings(self, new_num_tokens): - super().resize_token_embeddings(new_num_tokens=new_num_tokens) - - # LED is a special case where the bias has two dimensions - # and not named just `bias` - if new_num_tokens is not None: - num_tokens_to_copy = min(shape_list(self.final_logits_bias), new_num_tokens) - init_bias = tf.zeros((new_num_tokens,)) - init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy] - self.final_logits_bias = self.add_weight( - shape=(1, new_num_tokens), - initializer="zeros", - trainable=False, - name="final_logits_bias", - ) - self.final_logits_bias.assign(init_bias) + def get_encoder(self): + return self.led.encoder + + def get_bias(self): + return {"final_logits_bias": self.final_logits_bias} + + def set_bias(self, value): + self.final_logits_bias = value["final_logits_bias"] + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @@ -2266,12 +2302,6 @@ def _reorder_cache(past, beam_idx): ) return (past[0], reordered_past) - def get_output_embeddings(self): - return self.led.shared - - def get_encoder(self): - return self.led.encoder - def compute_loss(self, labels, logits): """CrossEntropyLoss that ignores pad tokens""" loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 4a1a294091a6..66bfd5d9fc05 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -14,6 +14,7 @@ # limitations under the License. """Tensorflow Longformer model. 
""" +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -437,6 +438,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.word_embeddings = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.act(hidden_states) @@ -1602,7 +1617,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -2040,13 +2055,11 @@ def __init__(self, config, *inputs, **kwargs): self.longformer = TFLongformerMainLayer(config, add_pooling_layer=False, name="longformer") self.lm_head = TFLongformerLMHead(config, self.longformer.embeddings, name="lm_head") - def get_output_embeddings(self): - return self.lm_head.decoder - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name @add_start_docstrings_to_model_forward(LONGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/lxmert/modeling_tf_lxmert.py b/src/transformers/models/lxmert/modeling_tf_lxmert.py index 1e68348ce6f3..cc5d93aa9502 100644 --- a/src/transformers/models/lxmert/modeling_tf_lxmert.py +++ b/src/transformers/models/lxmert/modeling_tf_lxmert.py @@ -16,6 +16,7 @@ # limitations under the License. """ TF 2.0 LXMERT model. """ +import warnings from dataclasses import dataclass from typing import Dict, Optional, Tuple @@ -706,10 +707,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): raise NotImplementedError @@ -1103,6 +1101,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.input_embeddings(hidden_states, mode="linear") @@ -1292,13 +1304,11 @@ def dummy_inputs(self): **({"obj_labels": obj_labels} if self.config.task_obj_predict else {}), } - def get_output_embeddings(self): - return self.lxmert.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.cls.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.cls.name + "/" + self.cls.predictions.name @add_start_docstrings_to_model_forward(LXMERT_INPUTS_DOCSTRING) diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index ceaba9390575..8bc5742e5d1d 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -15,6 +15,7 @@ # limitations under the License. """ TF 2.0 MobileBERT model. """ +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -665,6 +666,20 @@ def build(self, input_shape): ) super().build(input_shape) + def get_output_embeddings(self): + return self + + def set_output_embeddings(self, value): + self.decoder = value + self.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = tf.matmul(hidden_states, tf.concat([tf.transpose(self.decoder), self.dense], axis=0)) @@ -704,10 +719,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -1038,13 +1050,11 @@ def __init__(self, config, *inputs, **kwargs): self.predictions = TFMobileBertMLMHead(config, name="predictions___cls") self.seq_relationship = TFMobileBertOnlyNSPHead(2, name="seq_relationship___cls") - def get_output_embeddings(self): - return self.predictions.predictions - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.predictions.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.predictions.name + "/" + self.predictions.predictions.name @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @@ -1148,13 +1158,11 @@ def __init__(self, config, *inputs, **kwargs): self.mobilebert = TFMobileBertMainLayer(config, add_pooling_layer=False, name="mobilebert") self.mlm = TFMobileBertMLMHead(config, name="mlm___cls") - def get_output_embeddings(self): - return self.mlm.predictions - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name @add_start_docstrings_to_model_forward(MOBILEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/mpnet/modeling_tf_mpnet.py b/src/transformers/models/mpnet/modeling_tf_mpnet.py index 70a7a29e21bc..278ab625b40d 100644 --- a/src/transformers/models/mpnet/modeling_tf_mpnet.py +++ b/src/transformers/models/mpnet/modeling_tf_mpnet.py @@ -17,6 +17,7 @@ import math +import warnings import tensorflow as tf @@ -541,7 +542,7 @@ def get_input_embeddings(self): # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads def _prune_heads(self, heads_to_prune): @@ -839,6 +840,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.word_embeddings = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, features): x = self.dense(features) x = self.act(x) @@ -861,13 +876,11 @@ def __init__(self, config, *inputs, **kwargs): self.mpnet = TFMPNetMainLayer(config, name="mpnet") self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head") - def get_output_embeddings(self): - return self.mpnet.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py index 4cd689dad88c..1b68ef8e6874 100644 --- a/src/transformers/models/openai/modeling_tf_openai.py +++ b/src/transformers/models/openai/modeling_tf_openai.py @@ -219,7 +219,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.tokens_embed.weight = value - self.tokens_embed.vocab_size = value.shape[0] + self.tokens_embed.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -580,7 +580,10 @@ def __init__(self, config, *inputs, **kwargs): self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") def get_output_embeddings(self): - return self.transformer.tokens_embed + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_code_sample_docstrings( @@ -688,9 +691,6 @@ def __init__(self, config, *inputs, **kwargs): config, initializer_range=config.initializer_range, name="multiple_choice_head" ) - def get_output_embeddings(self): - return self.transformer.tokens_embed - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -850,9 +850,6 @@ def __init__(self, config, *inputs, **kwargs): ) self.transformer = TFOpenAIGPTMainLayer(config, name="transformer") - def get_output_embeddings(self): - return self.transformer.tokens_embed - @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 042ec47f48c5..13c8eea9c8a7 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -15,6 +15,8 @@ # limitations under the License. """ TF 2.0 RoBERTa model. 
""" +import warnings + import tensorflow as tf from ...activations_tf import get_tf_activation @@ -502,7 +504,7 @@ def get_input_embeddings(self): # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value - self.embeddings.vocab_size = value.shape[0] + self.embeddings.vocab_size = shape_list(value)[0] # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads def _prune_heads(self, heads_to_prune): @@ -826,6 +828,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.word_embeddings = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.act(hidden_states) @@ -848,13 +864,11 @@ def __init__(self, config, *inputs, **kwargs): self.roberta = TFRobertaMainLayer(config, add_pooling_layer=False, name="roberta") self.lm_head = TFRobertaLMHead(config, self.roberta.embeddings, name="lm_head") - def get_output_embeddings(self): - return self.lm_head.decoder - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_head def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_head.name @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 0dd154782001..6cb4419c7e4d 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -574,15 +574,6 @@ def __init__(self, config, embed_tokens=None, **kwargs): self.final_layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="final_layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) - def get_input_embeddings(self): - return self.embed_tokens - - def set_embed_tokens(self, embed_tokens): - self.embed_tokens = embed_tokens - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models - def _prune_heads(self, heads_to_prune): raise NotImplementedError # Not implemented yet in the library fr TF 2.0 models @@ -840,6 +831,26 @@ def serving(self, inputs): return self.serving_output(output) + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + try: + self.shared.weight = value + except AttributeError: + self(self.dummy_inputs) + self.shared.weight = value + + self.shared.vocab_size = shape_list(value)[0] + # retrieve correct absolute scope for embed token wrapper + with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: + pass + # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. 
+ embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) + self.encoder.embed_tokens = embed_tokens + if hasattr(self, "decoder"): + self.decoder.embed_tokens = embed_tokens + def _shift_right(self, input_ids): decoder_start_token_id = self.config.decoder_start_token_id pad_token_id = self.config.pad_token_id @@ -1051,20 +1062,6 @@ def __init__(self, config, *inputs, **kwargs): decoder_config.is_decoder = True self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared.weight = new_embeddings - self.shared.vocab_size = self.shared.weight.shape[0] - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - self.encoder.set_embed_tokens(embed_tokens) - self.decoder.set_embed_tokens(embed_tokens) - def get_encoder(self): return self.encoder @@ -1223,24 +1220,23 @@ def __init__(self, config, *inputs, **kwargs): if not config.tie_word_embeddings: self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head") - def get_input_embeddings(self): - return self.shared - def get_output_embeddings(self): if self.config.tie_word_embeddings: - return self.shared + return self.get_input_embeddings() else: - return self.lm_head + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + return tf.transpose(self.lm_head.kernel) - def set_input_embeddings(self, new_embeddings): - self.shared.weight = new_embeddings - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. 
- embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - self.encoder.set_embed_tokens(embed_tokens) - self.decoder.set_embed_tokens(embed_tokens) + def set_output_embeddings(self, value): + if self.config.tie_word_embeddings: + self.set_input_embeddings(value) + else: + self.lm_head = tf.keras.layers.Dense(shape_list(value)[0], use_bias=False, name="lm_head") + # in a dense layer the kernel has a shape (last_dim, units), for us (dim, num_tokens) + # value has a shape (num_tokens, dim) then needs to be transposed + transposed_value = tf.transpose(value) + self.lm_head.kernel = transposed_value def get_encoder(self): return self.encoder @@ -1359,9 +1355,9 @@ def call( # T5v1.1 does not tie output word embeddings and thus does not require downscaling if self.config.tie_word_embeddings: sequence_output = sequence_output * (self.model_dim ** -0.5) - logits = self.get_output_embeddings()(sequence_output, mode="linear") + logits = self.shared(sequence_output, mode="linear") else: - logits = self.get_output_embeddings()(sequence_output) + logits = self.lm_head(sequence_output) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) @@ -1489,19 +1485,6 @@ def __init__(self, config, *inputs, **kwargs): encoder_config.use_cache = False self.encoder = TFT5MainLayer(encoder_config, embed_tokens, name="encoder") - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, new_embeddings): - self.shared.weight = new_embeddings - self.shared.vocab_size = self.shared.weight.shape[0] - # retrieve correct absolute scope for embed token wrapper - with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name: - pass - # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope. - embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name) - self.encoder.set_embed_tokens(embed_tokens) - def get_encoder(self): return self.encoder diff --git a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py index 9aec7949bda1..ac43eb300154 100644 --- a/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py +++ b/src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py @@ -468,9 +468,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): raise NotImplementedError - def _resize_token_embeddings(self, new_num_tokens): - return self.word_emb - def backward_compatible(self): self.sample_softmax = -1 @@ -909,25 +906,6 @@ def serving_output(self, output): ) -class TFTransfoXLMHead(tf.keras.layers.Layer): - def __init__(self, config, input_embeddings, **kwargs): - super().__init__(**kwargs) - self.vocab_size = config.vocab_size - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. 
- self.input_embeddings = input_embeddings - - def build(self, input_shape): - self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") - super().build(input_shape) - - def call(self, hidden_states): - hidden_states = self.input_embeddings(hidden_states, mode="linear") - hidden_states = hidden_states + self.bias - return hidden_states - - @add_start_docstrings( """ The Transformer-XL Model with a language modeling head on top (adaptive softmax with weights tied to the adaptive @@ -948,6 +926,9 @@ def __init__(self, config): config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val, name="crit" ) + def _resize_token_embeddings(self, new_num_tokens): + raise NotImplementedError() + def get_output_embeddings(self): """Double-check if you are using adaptive softmax.""" if len(self.crit.out_layers) > 0: diff --git a/src/transformers/models/xlm/modeling_tf_xlm.py b/src/transformers/models/xlm/modeling_tf_xlm.py index 2c08d2685022..f604d72dcb7f 100644 --- a/src/transformers/models/xlm/modeling_tf_xlm.py +++ b/src/transformers/models/xlm/modeling_tf_xlm.py @@ -17,6 +17,7 @@ """ import itertools +import warnings from dataclasses import dataclass from typing import Optional, Tuple @@ -330,10 +331,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings.weight = value - self.embeddings.vocab_size = value.shape[0] - - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError + self.embeddings.vocab_size = shape_list(value)[0] def _prune_heads(self, heads_to_prune): """ @@ -790,6 +788,20 @@ def build(self, input_shape): super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -810,13 +822,11 @@ def __init__(self, config, *inputs, **kwargs): self.transformer = TFXLMMainLayer(config, name="transformer") self.pred_layer = TFXLMPredLayer(config, self.transformer.embeddings, name="pred_layer_._proj") - def get_output_embeddings(self): - return self.pred_layer.input_embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.pred_layer def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.pred_layer.name def prepare_inputs_for_generation(self, inputs, **kwargs): diff --git a/src/transformers/models/xlnet/modeling_tf_xlnet.py b/src/transformers/models/xlnet/modeling_tf_xlnet.py index c93ed3124efd..7756767f72db 100644 --- a/src/transformers/models/xlnet/modeling_tf_xlnet.py +++ b/src/transformers/models/xlnet/modeling_tf_xlnet.py @@ -17,6 +17,7 @@ TF 2.0 XLNet model. 
""" +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple @@ -407,6 +408,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + def get_output_embeddings(self): + return self.input_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.weight = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + def call(self, hidden_states): hidden_states = self.input_embeddings(hidden_states, mode="linear") hidden_states = hidden_states + self.bias @@ -450,7 +465,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.word_embedding.weight = value - self.word_embedding.vocab_size = value.shape[0] + self.word_embedding.vocab_size = shape_list(value)[0] def build(self, input_shape): initializer = get_initializer(self.initializer_range) @@ -458,9 +473,6 @@ def build(self, input_shape): shape=(1, 1, self.d_model), initializer=initializer, trainable=True, name="mask_emb" ) - def _resize_token_embeddings(self, new_num_tokens): - raise NotImplementedError - def _prune_heads(self, heads_to_prune): raise NotImplementedError @@ -1230,13 +1242,11 @@ def __init__(self, config, *inputs, **kwargs): self.transformer = TFXLNetMainLayer(config, name="transformer") self.lm_loss = TFXLNetLMHead(config, self.transformer.word_embedding, name="lm_loss") - def get_output_embeddings(self): - return self.lm_loss.input_embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.lm_loss def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. 
Please use `get_bias` instead.", FutureWarning) return self.name + "/" + self.lm_loss.name def prepare_inputs_for_generation(self, inputs, past, use_mems=None, **kwargs): diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 50b53292110f..dbdc8c8e1e5d 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -460,6 +460,20 @@ def build(self, input_shape): self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") super().build(input_shape) + + def get_output_embeddings(self): + return self.input_embeddings.word_embeddings + + def set_output_embeddings(self, value): + self.input_embeddings.word_embeddings = value + self.input_embeddings.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] def call(self, hidden_states): hidden_states = self.transform(hidden_states) @@ -803,15 +817,9 @@ def __init__(self, config, *inputs, **kwargs): self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls") - def get_output_embeddings(self): - return self.{{cookiecutter.lowercase_modelname}}.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions - def get_prefix_bias_name(self): - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -909,15 +917,9 @@ def __init__(self, config, *inputs, **kwargs): self.{{cookiecutter.lowercase_modelname}} = TF{{cookiecutter.camelcase_modelname}}MainLayer(config, name="{{cookiecutter.lowercase_modelname}}") self.mlm = TF{{cookiecutter.camelcase_modelname}}MLMHead(config, self.{{cookiecutter.lowercase_modelname}}.embeddings, name="mlm___cls") - def get_output_embeddings(self): - return self.{{cookiecutter.lowercase_modelname}}.embeddings - - def get_output_layer_with_bias(self): + def get_lm_head(self): return self.mlm.predictions - def get_prefix_bias_name(self): - return self.name + "/" + self.mlm.name + "/" + self.mlm.predictions.name - @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="{{cookiecutter.checkpoint_identifier}}", @@ -1875,6 +1877,29 @@ def dummy_inputs(self): } return dummy_inputs + def get_input_embeddings(self): + base_model = getattr(self, self.base_model_prefix, self) + + return base_model.shared + + def set_input_embeddings(self, value): + base_model = getattr(self, self.base_model_prefix, self) + + try: + base_model.shared.weight = value + except AttributeError: + self(self.dummy_inputs) + base_model.shared.weight = value + + base_model.shared.vocab_size = shape_list(base_model.shared.weight)[0] + + with 
tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name: + pass + + embed_tokens = TFWrappedEmbeddings(base_model.shared, abs_scope_name=shared_abs_scope_name) + base_model.encoder.set_embed_tokens(embed_tokens) + base_model.decoder.set_embed_tokens(embed_tokens) + @tf.function( input_signature=[ { @@ -2004,6 +2029,9 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.layers = [TF{{cookiecutter.camelcase_modelname}}EncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)] self.layernorm_embedding = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding") + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -2144,6 +2172,9 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, embed_tok self.dropout = tf.keras.layers.Dropout(config.dropout) + def set_embed_tokens(self, embed_tokens): + self.embed_tokens = embed_tokens + def call( self, input_ids=None, @@ -2351,6 +2382,9 @@ def __init__(self, config: {{cookiecutter.camelcase_modelname}}Config, *inputs, self.encoder = TF{{cookiecutter.camelcase_modelname}}Encoder(config, embed_tokens, name="encoder") self.decoder = TF{{cookiecutter.camelcase_modelname}}Decoder(config, embed_tokens, name="decoder") + def get_encoder(self): + return self.encoder + def get_decoder(self): return self.decoder @@ -2471,15 +2505,6 @@ def serving_output(self, output): encoder_attentions=enc_attns, ) - def get_input_embeddings(self): - return self.shared - - def set_input_embeddings(self, value): - self.shared = value - - def get_output_embeddings(self): - return self.shared - @add_start_docstrings( "The {{cookiecutter.uppercase_modelname}} Model with a language modeling head. 
Can be used for summarization.", @@ -2502,23 +2527,21 @@ def __init__(self, config, *inputs, **kwargs): def get_decoder(self): return self.model.decoder + + def get_encoder(self): + return self.model.encoder - def resize_token_embeddings(self, new_num_tokens): - super().resize_token_embeddings(new_num_tokens=new_num_tokens) - - # {{cookiecutter.uppercase_modelname}} is a special case where the bias has two dimensions - # and not named just `bias` - if new_num_tokens is not None: - num_tokens_to_copy = min(shape_list(self.final_logits_bias)[0], new_num_tokens) - init_bias = tf.zeros((new_num_tokens,)) - init_bias[:num_tokens_to_copy] = self.final_logits_bias.value()[:num_tokens_to_copy] - self.final_logits_bias = self.add_weight( - shape=(1, new_num_tokens), - initializer="zeros", - trainable=False, - name="final_logits_bias", - ) - self.final_logits_bias.assign(init_bias) + def get_bias(self): + return {"final_logits_bias": self.final_logits_bias} + + def set_bias(self, value): + self.final_logits_bias = value["final_logits_bias"] + + def get_output_embeddings(self): + return self.get_input_embeddings() + + def set_output_embeddings(self, value): + self.set_input_embeddings(value) @add_start_docstrings_to_model_forward({{cookiecutter.uppercase_modelname}}_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) @@ -2682,12 +2705,6 @@ def _reorder_cache(past, beam_idx): ) return (past[0], reordered_past) - def get_output_embeddings(self): - return self.model.shared - - def get_encoder(self): - return self.model.encoder - def compute_loss(self, labels, logits): """CrossEntropyLoss that ignores pad tokens""" loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 9b6e190b1d63..7ccd691ca622 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -486,10 +486,82 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. 
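+                # Calling the model on its dummy inputs builds the layers, which creates the weight variable looked up on retry.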
+ model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. + assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): diff --git a/tests/test_modeling_tf_albert.py b/tests/test_modeling_tf_albert.py index 354e116671a8..5902c2b01fb3 100644 --- a/tests/test_modeling_tf_albert.py +++ b/tests/test_modeling_tf_albert.py @@ -274,14 +274,24 @@ def test_for_question_answering(self): def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + list_lm_models = [TFAlbertForPreTraining, TFAlbertForMaskedLM] for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in list_lm_models: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None @slow def test_model_from_pretrained(self): diff --git a/tests/test_modeling_tf_bart.py b/tests/test_modeling_tf_bart.py index 58756fbf349d..2765eb7a418a 100644 --- a/tests/test_modeling_tf_bart.py +++ 
b/tests/test_modeling_tf_bart.py @@ -159,10 +159,82 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. 
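+                # A size of None requests no resize, so the expected vocabulary size falls back to config.vocab_size.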
+ assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI diff --git a/tests/test_modeling_tf_bert.py b/tests/test_modeling_tf_bert.py index a1d2bb747acf..7289a5d47332 100644 --- a/tests/test_modeling_tf_bert.py +++ b/tests/test_modeling_tf_bert.py @@ -340,15 +340,17 @@ def test_model_common_attributes(self): assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) if model_class in list_lm_models: - x = model.get_output_layer_with_bias() + x = model.get_output_embeddings() assert isinstance(x, tf.keras.layers.Layer) - name = model.get_prefix_bias_name() - assert isinstance(name, str) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) else: - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() + x = model.get_output_embeddings() assert x is None + name = model.get_bias() + assert name is None def test_custom_load_tf_weights(self): model, output_loading_info = TFBertForTokenClassification.from_pretrained( diff --git a/tests/test_modeling_tf_blenderbot.py b/tests/test_modeling_tf_blenderbot.py index 662d33c4c341..7cb63f47d619 100644 --- a/tests/test_modeling_tf_blenderbot.py +++ b/tests/test_modeling_tf_blenderbot.py @@ -57,29 +57,93 @@ def test_inputs_embeds(self): # inputs_embeds not supported pass - def test_saved_model_with_hidden_states_output(self): - # Should be uncommented during patrick TF refactor - pass - - def test_saved_model_with_attentions_output(self): - # Should be uncommented during patrick TF refactor - pass - def test_model_common_attributes(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + 
assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. + assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + @is_pt_tf_cross_test @require_tokenizers diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 702b531b6c19..db8436461720 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -41,7 +41,6 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - TFAdaptiveEmbedding, TFSharedEmbeddings, tf_top_k_top_p_filtering, ) @@ -671,18 +670,20 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) - assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding)) + assert 
isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) if model_class in list_lm_models: - x = model.get_output_layer_with_bias() + x = model.get_output_embeddings() assert isinstance(x, tf.keras.layers.Layer) - name = model.get_prefix_bias_name() - assert isinstance(name, str) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) else: - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() + x = model.get_output_embeddings() assert x is None + name = model.get_bias() + assert name is None def test_determinism(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -809,26 +810,71 @@ def test_resize_token_embeddings(self): if not self.test_resize_embeddings: return config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - INPUT_SHAPE = [1, 10, config.hidden_size] + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "word_embeddings"): + return embedding_layer.word_embeddings + elif hasattr(embedding_layer, "weight"): + return embedding_layer.weight + elif hasattr(embedding_layer, "decoder"): + return embedding_layer.decoder + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "word_embeddings"): + return embedding_layer.word_embeddings + elif hasattr(embedding_layer, "weight"): + return embedding_layer.weight + elif hasattr(embedding_layer, "decoder"): + return embedding_layer.decoder + else: + return None + for model_class in self.all_model_classes: for size in [config.vocab_size - 10, config.vocab_size + 10, None]: # build the embeddings model = model_class(config=config) - emb_old = model.get_input_embeddings() - emb_old.build(INPUT_SHAPE) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_bias = model.get_bias() + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) # reshape the embeddings - new_embeddings = model._get_resized_embeddings(emb_old, size) - # # check that the resized embeddings size matches the desired size. + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_bias = model.get_bias() + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + + # check that the resized embeddings size matches the desired size. 
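+                # The weight comparisons further below only cover the overlapping rows, since zip stops at the shorter tensor.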
assert_size = size if size is not None else config.vocab_size - self.assertEqual(new_embeddings.shape[0], assert_size) + self.assertEqual(new_input_embeddings.shape[0], assert_size) + # check that weights remain the same after resizing - emd_old_weights = model._get_word_embeddings(emb_old) models_equal = True - for p1, p2 in zip(emd_old_weights.numpy(), new_embeddings.numpy()): - if np.sum(abs(p1 - p2)) > 0: + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: models_equal = False self.assertTrue(models_equal) + if old_bias is not None and new_bias is not None: + for old_weight, new_weight in zip(old_bias.values(), new_bias.values()): + self.assertEqual(new_weight.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_weight.value(), new_weight.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + self.assertEqual(new_output_embeddings.shape[1], old_output_embeddings.shape[1]) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + def test_lm_head_model_random_no_beam_search_generate(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() input_ids = inputs_dict["input_ids"] diff --git a/tests/test_modeling_tf_ctrl.py b/tests/test_modeling_tf_ctrl.py index 4231c151bdd6..e2bcb69709f7 100644 --- a/tests/test_modeling_tf_ctrl.py +++ b/tests/test_modeling_tf_ctrl.py @@ -193,6 +193,33 @@ def test_ctrl_sequence_classification_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_ctrl_for_sequence_classification(*config_and_inputs) + def test_model_common_attributes(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + list_lm_models = [TFCTRLLMHeadModel] + list_other_models_with_output_ebd = [TFCTRLForSequenceClassification] + + for model_class in self.all_model_classes: + model = model_class(config) + assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) + + if model_class in list_lm_models: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + elif model_class in list_other_models_with_output_ebd: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + @slow def test_model_from_pretrained(self): for model_name in TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_tf_gpt2.py b/tests/test_modeling_tf_gpt2.py index c0781a2947ae..93c919a57a3b 100644 --- a/tests/test_modeling_tf_gpt2.py +++ b/tests/test_modeling_tf_gpt2.py @@ -365,10 +365,17 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + 
if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_gpt2_sequence_classification_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py index a6eb83a32676..0245985e1c0d 100644 --- a/tests/test_modeling_tf_led.py +++ b/tests/test_modeling_tf_led.py @@ -199,10 +199,82 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None + + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. 
+ assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_tf_lxmert.py b/tests/test_modeling_tf_lxmert.py index e50151256158..eceff0c48efc 100644 --- a/tests/test_modeling_tf_lxmert.py +++ b/tests/test_modeling_tf_lxmert.py @@ -687,15 +687,17 @@ def test_model_common_attributes(self): assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) if model_class in list_lm_models: - x = model.get_output_layer_with_bias() + x = model.get_output_embeddings() assert isinstance(x, tf.keras.layers.Layer) - name = model.get_prefix_bias_name() - assert isinstance(name, str) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) else: - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() + x = model.get_output_embeddings() assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI diff --git a/tests/test_modeling_tf_marian.py b/tests/test_modeling_tf_marian.py index 3be3d9832da3..7c3c87e0d3bd 100644 --- a/tests/test_modeling_tf_marian.py +++ b/tests/test_modeling_tf_marian.py @@ -38,7 +38,7 @@ class ModelTester(TFBartModelTester): @require_tf -class TestTFMarianCommon(TFModelTesterMixin, unittest.TestCase): +class TFMarianMTModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFMarianMTModel,) if is_tf_available() else () all_generative_model_classes = (TFMarianMTModel,) if is_tf_available() else () model_tester_cls = ModelTester @@ -56,13 +56,6 @@ def test_inputs_embeds(self): # inputs_embeds not supported pass - def test_saved_model_with_hidden_states_output(self): - # Should be uncommented during patrick TF refactor - pass - - def test_saved_model_with_attentions_output(self): - pass - def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -100,15 +93,87 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), 
tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. 
+ assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + class AbstractMarianIntegrationTest(unittest.TestCase): maxDiff = 1000 # show more chars for failing integration tests diff --git a/tests/test_modeling_tf_mbart.py b/tests/test_modeling_tf_mbart.py index 5724bef50834..0916f19b111e 100644 --- a/tests/test_modeling_tf_mbart.py +++ b/tests/test_modeling_tf_mbart.py @@ -36,7 +36,7 @@ class ModelTester(TFBartModelTester): @require_tf -class TestTFMBartCommon(TFModelTesterMixin, unittest.TestCase): +class TFMBartModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFMBartForConditionalGeneration,) if is_tf_available() else () model_tester_cls = ModelTester @@ -54,14 +54,6 @@ def test_inputs_embeds(self): # inputs_embeds not supported pass - def test_saved_model_with_hidden_states_output(self): - # Should be uncommented during patrick TF refactor - pass - - def test_saved_model_with_attentions_output(self): - # Should be uncommented during patrick TF refactor - pass - def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -99,15 +91,87 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we 
build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. + model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. + assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + @is_pt_tf_cross_test @require_sentencepiece diff --git a/tests/test_modeling_tf_mobilebert.py b/tests/test_modeling_tf_mobilebert.py index c090b1016276..8995e6cf9bb7 100644 --- a/tests/test_modeling_tf_mobilebert.py +++ b/tests/test_modeling_tf_mobilebert.py @@ -292,15 +292,17 @@ def test_model_common_attributes(self): assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) if model_class in list_lm_models: - x = model.get_output_layer_with_bias() + x = model.get_output_embeddings() assert isinstance(x, tf.keras.layers.Layer) - name = model.get_prefix_bias_name() - assert isinstance(name, str) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) else: - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() + x = model.get_output_embeddings() assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI diff --git a/tests/test_modeling_tf_openai.py b/tests/test_modeling_tf_openai.py index 990a417ee364..a7ffe33c5225 100644 --- a/tests/test_modeling_tf_openai.py +++ 
b/tests/test_modeling_tf_openai.py @@ -228,10 +228,17 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_openai_gpt_sequence_classification_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/tests/test_modeling_tf_pegasus.py b/tests/test_modeling_tf_pegasus.py index d54aa26ae373..5c91f35af294 100644 --- a/tests/test_modeling_tf_pegasus.py +++ b/tests/test_modeling_tf_pegasus.py @@ -41,7 +41,7 @@ class ModelTester(TFBartModelTester): @require_tf -class TestTFPegasusCommon(TFModelTesterMixin, unittest.TestCase): +class TFPegasusModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () all_generative_model_classes = (TFPegasusForConditionalGeneration,) if is_tf_available() else () model_tester_cls = ModelTester @@ -59,14 +59,6 @@ def test_inputs_embeds(self): # inputs_embeds not supported pass - def test_saved_model_with_hidden_states_output(self): - # Should be uncommented during patrick TF refactor - pass - - def test_saved_model_with_attentions_output(self): - # Should be uncommented during patrick TF refactor - pass - def test_compile_tf_model(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -104,15 +96,87 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert isinstance(name, dict) + for k, v in name.items(): + assert isinstance(v, tf.Variable) + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI pass + def test_resize_token_embeddings(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def _get_word_embedding_weight(model, embedding_layer): + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + # Here we build the word embeddings weights if not exists. + # And then we retry to get the attribute once built. 
+ model(model.dummy_inputs) + if hasattr(embedding_layer, "weight"): + return embedding_layer.weight + else: + return None + + for model_class in self.all_model_classes: + for size in [config.vocab_size - 10, config.vocab_size + 10, None]: + # build the embeddings + model = model_class(config=config) + old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + old_final_logits_bias = model.get_bias() + + # reshape the embeddings + model.resize_token_embeddings(size) + new_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings()) + new_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings()) + new_final_logits_bias = model.get_bias() + + # check that the resized embeddings size matches the desired size. + assert_size = size if size is not None else config.vocab_size + + self.assertEqual(new_input_embeddings.shape[0], assert_size) + + # check that weights remain the same after resizing + models_equal = True + for p1, p2 in zip(old_input_embeddings.value(), new_input_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_output_embeddings is not None and new_output_embeddings is not None: + self.assertEqual(new_output_embeddings.shape[0], assert_size) + + models_equal = True + for p1, p2 in zip(old_output_embeddings.value(), new_output_embeddings.value()): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + + if old_final_logits_bias is not None and new_final_logits_bias is not None: + old_final_logits_bias = old_final_logits_bias["final_logits_bias"] + new_final_logits_bias = new_final_logits_bias["final_logits_bias"] + self.assertEqual(new_final_logits_bias.shape[0], 1) + self.assertEqual(new_final_logits_bias.shape[1], assert_size) + + models_equal = True + for old, new in zip(old_final_logits_bias.value(), new_final_logits_bias.value()): + for p1, p2 in zip(old, new): + if tf.math.reduce_sum(tf.math.abs(p1 - p2)) > 0: + models_equal = False + self.assertTrue(models_equal) + @is_pt_tf_cross_test @require_sentencepiece diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index a51ed2e63651..f4d0f2ce19d4 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -289,10 +289,17 @@ def test_model_common_attributes(self): for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + + if model_class in self.all_generative_model_classes: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None def test_saved_model_creation(self): # This test is too long (>30sec) and makes fail the CI diff --git a/tests/test_modeling_tf_transfo_xl.py b/tests/test_modeling_tf_transfo_xl.py index c92972362424..14d3e5a3acf7 100644 --- a/tests/test_modeling_tf_transfo_xl.py +++ b/tests/test_modeling_tf_transfo_xl.py @@ -187,14 +187,21 @@ def test_transfo_xl_sequence_classification_model(self): def test_model_common_attributes(self): config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + list_other_models_with_output_ebd = [TFTransfoXLForSequenceClassification] for model_class in self.all_model_classes: model = model_class(config) assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer) - x = model.get_output_layer_with_bias() - assert x is None - name = model.get_prefix_bias_name() - assert name is None + if model_class in list_other_models_with_output_ebd: + x = model.get_output_embeddings() + assert isinstance(x, tf.keras.layers.Layer) + name = model.get_bias() + assert name is None + else: + x = model.get_output_embeddings() + assert x is None + name = model.get_bias() + assert name is None @slow def test_model_from_pretrained(self):
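
As a reference for the pattern these hunks repeat across the Marian, mBART and Pegasus tests, the resize check boils down to a single helper: snapshot the input-embedding matrix, resize, then verify the new matrix has the requested number of rows and that the rows shared with the old matrix are unchanged. The sketch below is illustrative only; it assumes a TF model exposing the accessors touched by this patch (`get_input_embeddings`, `resize_token_embeddings`, `dummy_inputs`), and the helper name is a placeholder, not part of the diff.

import tensorflow as tf

def check_resize_preserves_overlap(model, new_size):
    """Resize the token embeddings and verify the overlapping rows are untouched.

    `new_size` may be larger or smaller than the current vocabulary, or None to
    keep it as-is, mirroring the three sizes exercised by the new tests.
    """
    embedding_layer = model.get_input_embeddings()
    # As in the tests, the `weight` attribute may only exist once the model is built.
    if not hasattr(embedding_layer, "weight"):
        model(model.dummy_inputs)
        embedding_layer = model.get_input_embeddings()
    old_weights = tf.identity(embedding_layer.weight)  # snapshot before resizing

    model.resize_token_embeddings(new_size)
    new_weights = model.get_input_embeddings().weight

    # The resized matrix must match the requested size (or the config when None).
    expected = new_size if new_size is not None else model.config.vocab_size
    assert new_weights.shape[0] == expected

    # Only the rows shared by the old and new matrices are expected to match.
    overlap = min(old_weights.shape[0], new_weights.shape[0])
    tf.debugging.assert_equal(old_weights[:overlap], new_weights[:overlap])

For instance, `check_resize_preserves_overlap(model_class(config), config.vocab_size + 10)` reproduces the "grow" case of the new tests, `config.vocab_size - 10` the "shrink" case, and `None` the no-op case.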
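
Similarly, the `test_model_common_attributes` updates for MobileBERT, OpenAI GPT, Pegasus, T5 and Transfo-XL all assert the same contract for the new accessors. A minimal sketch under the same assumptions (the helper name and the `has_lm_head` flag are placeholders, not part of the diff):

import tensorflow as tf

def check_common_attributes(model, has_lm_head):
    """Spell out the accessor contract the updated common-attribute tests rely on."""
    assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)

    output_embeddings = model.get_output_embeddings()
    bias = model.get_bias()
    if has_lm_head:
        # LM-headed models expose their output embeddings as a layer ...
        assert isinstance(output_embeddings, tf.keras.layers.Layer)
        # ... and, when the head carries a bias (e.g. final_logits_bias), the
        # bias comes back as a dict mapping attribute names to tf.Variable.
        if bias is not None:
            assert all(isinstance(v, tf.Variable) for v in bias.values())
    else:
        # Models without an LM head return None from both new accessors.
        assert output_embeddings is None
        assert bias is None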