diff --git a/lit_nlp/examples/models/glue_models.py b/lit_nlp/examples/models/glue_models.py index ec3b35bb..466f3172 100644 --- a/lit_nlp/examples/models/glue_models.py +++ b/lit_nlp/examples/models/glue_models.py @@ -433,12 +433,7 @@ def predict_minibatch(self, inputs: Iterable[JsonDict]): # Gathers word embeddings from BERT model embedding layer using input ids # of the tokens. input_ids = encoded_input["input_ids"] - # TODO(b/236276775): Unify on the TFBertEmbeddings.weight API after - # transformers is updated to v4.25.1 (or newer). - if hasattr(self.model.bert.embeddings, "word_embeddings"): - word_embeddings = self.model.bert.embeddings.word_embeddings - else: - word_embeddings = self.model.bert.embeddings.weight + word_embeddings = self.model.bert.embeddings.weight # [batch_size, num_tokens, emb_size] input_embs = tf.gather(word_embeddings, input_ids)