Skip to content

Commit

Permalink
Clean TF Bert (#9788)
Browse files Browse the repository at this point in the history
* Start cleaning BERT

* Clean BERT and all those depends of it

* Fix attribute name

* Apply style

* Apply Sylvain's comments

* Apply Lysandre's comments

* remove unused import
  • Loading branch information
jplu authored Jan 27, 2021
1 parent f0329ea commit 4adbdce
Show file tree
Hide file tree
Showing 15 changed files with 1,295 additions and 1,059 deletions.
4 changes: 4 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
logger = logging.get_logger(__name__)
tf_logger = tf.get_logger()

TFModelInputType = Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
]


class TFModelUtilsMixin:
"""
Expand Down
66 changes: 37 additions & 29 deletions src/transformers/models/albert/modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import tensorflow as tf

Expand Down Expand Up @@ -82,16 +82,16 @@ def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float,
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape=input_shape)
super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
Expand All @@ -101,14 +101,14 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])

return embeddings

Expand All @@ -122,16 +122,16 @@ def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: fl
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape=input_shape)
super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
Expand All @@ -141,15 +141,15 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])

return embeddings

Expand All @@ -163,16 +163,16 @@ def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_r
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
Expand All @@ -182,8 +182,8 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]

return tf.broadcast_to(input=position_embeddings, shape=input_shape)
Expand Down Expand Up @@ -218,7 +218,14 @@ def __init__(self, config, **kwargs):
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
def call(
self,
input_ids: tf.Tensor,
position_ids: tf.Tensor,
token_type_ids: tf.Tensor,
inputs_embeds: tf.Tensor,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.
Expand Down Expand Up @@ -879,7 +886,7 @@ def call(
return outputs

# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1105,7 +1112,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1208,7 +1215,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1310,7 +1317,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1425,7 +1432,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1572,13 +1579,14 @@ def call(
}
]
)
def serving(self, inputs):
output = self.call(inputs)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)

return self.serving_output(output)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down
Loading

0 comments on commit 4adbdce

Please sign in to comment.