Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean TF Bert #9788

Merged
merged 7 commits into from
Jan 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@

logger = logging.get_logger(__name__)

TFModelInputType = Union[
List[tf.Tensor], List[np.ndarray], Dict[str, tf.Tensor], Dict[str, np.ndarray], np.ndarray, tf.Tensor
]


class TFModelUtilsMixin:
"""
Expand Down
66 changes: 37 additions & 29 deletions src/transformers/models/albert/modeling_tf_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import tensorflow as tf

Expand Down Expand Up @@ -82,16 +82,16 @@ def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float,
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape=input_shape)
super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
Expand All @@ -101,14 +101,14 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])

return embeddings

Expand All @@ -122,16 +122,16 @@ def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: fl
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.token_type_embeddings = self.add_weight(
name="embeddings",
shape=[self.type_vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape=input_shape)
super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"type_vocab_size": self.type_vocab_size,
"hidden_size": self.hidden_size,
Expand All @@ -141,15 +141,15 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, token_type_ids):
def call(self, token_type_ids: tf.Tensor) -> tf.Tensor:
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(token_type_ids), [self.hidden_size]], axis=0)
)

embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(token_type_ids.shape.as_list() + [self.hidden_size])

return embeddings

Expand All @@ -163,16 +163,16 @@ def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_r
self.hidden_size = hidden_size
self.initializer_range = initializer_range

def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)

super().build(input_shape)

def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"max_position_embeddings": self.max_position_embeddings,
"hidden_size": self.hidden_size,
Expand All @@ -182,8 +182,8 @@ def get_config(self):

return dict(list(base_config.items()) + list(config.items()))

def call(self, position_ids):
input_shape = shape_list(tensor=position_ids)
def call(self, position_ids: tf.Tensor) -> tf.Tensor:
input_shape = shape_list(position_ids)
position_embeddings = self.position_embeddings[: input_shape[1], :]

return tf.broadcast_to(input=position_embeddings, shape=input_shape)
Expand Down Expand Up @@ -218,7 +218,14 @@ def __init__(self, config, **kwargs):
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertEmbeddings.call
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
def call(
self,
input_ids: tf.Tensor,
position_ids: tf.Tensor,
token_type_ids: tf.Tensor,
inputs_embeds: tf.Tensor,
training: bool = False,
) -> tf.Tensor:
"""
Applies embedding based on inputs tensor.

Expand Down Expand Up @@ -876,7 +883,7 @@ def call(
return outputs

# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1102,7 +1109,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1205,7 +1212,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1307,7 +1314,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1422,7 +1429,7 @@ def call(
)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down Expand Up @@ -1569,13 +1576,14 @@ def call(
}
]
)
def serving(self, inputs):
output = self.call(inputs)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving
def serving(self, inputs: Dict[str, tf.Tensor]):
output = self.call(input_ids=inputs)

return self.serving_output(output)

# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None

Expand Down
Loading