diff --git a/CHANGELOG.md b/CHANGELOG.md
index b88644fd877..771dd7cbd5e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Renamed module `allennlp.data.tokenizers.token` to `allennlp.data.tokenizers.token_class` to avoid [this bug](https://github.com/allenai/allennlp/issues/4819).
+- `transformers` dependency updated to version 4.0.1.
 
 ### Fixed
 
diff --git a/allennlp/nn/util.py b/allennlp/nn/util.py
index 56793958e04..c593694aaed 100644
--- a/allennlp/nn/util.py
+++ b/allennlp/nn/util.py
@@ -1767,10 +1767,10 @@ def find_embedding_layer(model: torch.nn.Module) -> torch.nn.Module:
     """
     # We'll look for a few special cases in a first pass, then fall back to just finding a
     # TextFieldEmbedder in a second pass if we didn't find a special case.
-    from transformers.modeling_gpt2 import GPT2Model
-    from transformers.modeling_bert import BertEmbeddings
-    from transformers.modeling_albert import AlbertEmbeddings
-    from transformers.modeling_roberta import RobertaEmbeddings
+    from transformers.models.gpt2.modeling_gpt2 import GPT2Model
+    from transformers.models.bert.modeling_bert import BertEmbeddings
+    from transformers.models.albert.modeling_albert import AlbertEmbeddings
+    from transformers.models.roberta.modeling_roberta import RobertaEmbeddings
     from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
     from allennlp.modules.text_field_embedders.basic_text_field_embedder import (
         BasicTextFieldEmbedder,
diff --git a/setup.py b/setup.py
index af23ce36d8f..0d479b4b6d5 100644
--- a/setup.py
+++ b/setup.py
@@ -64,7 +64,8 @@
         "scikit-learn",
         "scipy",
         "pytest",
-        "transformers>=3.4,<3.6",
+        "transformers>=4.0,<4.1",
+        "sentencepiece",
         "jsonpickle",
         "dataclasses;python_version<'3.7'",
         "filelock>=3.0,<3.1",
diff --git a/tests/data/token_indexers/pretrained_transformer_indexer_test.py b/tests/data/token_indexers/pretrained_transformer_indexer_test.py
index f15f6096a36..d817af9b392 100644
--- a/tests/data/token_indexers/pretrained_transformer_indexer_test.py
+++ b/tests/data/token_indexers/pretrained_transformer_indexer_test.py
@@ -99,7 +99,7 @@ def test_transformers_vocab_sizes(self, model_name):
 
     def test_transformers_vocabs_added_correctly(self):
         namespace, model_name = "tags", "roberta-base"
-        tokenizer = cached_transformers.get_tokenizer(model_name)
+        tokenizer = cached_transformers.get_tokenizer(model_name, use_fast=False)
         allennlp_tokenizer = PretrainedTransformerTokenizer(model_name)
         indexer = PretrainedTransformerIndexer(model_name=model_name, namespace=namespace)
         allennlp_tokens = allennlp_tokenizer.tokenize("AllenNLP is great!")
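
For context on the `allennlp/nn/util.py` hunk: transformers 4.0 moved each model's code from the flat `transformers.modeling_<name>` modules into per-model packages under `transformers.models.<name>`, so the old import paths no longer exist. A minimal sketch of the rename, using an illustrative try/except fallback for code that must run against both major versions (the fallback is hypothetical and not what this diff does; the diff pins transformers 4.x and uses the new paths directly):

```python
# Illustrative only: transformers 4.0 relocated model code into
# transformers.models.<name>.modeling_<name>. The fallback import below is a
# hypothetical compatibility shim, not part of this diff.
try:
    from transformers.models.bert.modeling_bert import BertEmbeddings  # transformers >= 4.0
except ImportError:
    from transformers.modeling_bert import BertEmbeddings  # transformers < 4.0
```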
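The `setup.py` and test hunks track two other transformers 4.0 changes: `sentencepiece` is no longer a required dependency of transformers (so sentencepiece-based tokenizers such as ALBERT's would break without the new explicit requirement), and `AutoTokenizer.from_pretrained`, which `cached_transformers.get_tokenizer` wraps, now returns a fast Rust-backed tokenizer by default. Passing `use_fast=False` keeps the test on the slow Python tokenizer it exercised before the upgrade. A small sketch of the new default, assuming the `roberta-base` tokenizer files can be downloaded:

```python
from transformers import AutoTokenizer

# In transformers 4.x the default is a fast (Rust) tokenizer...
fast = AutoTokenizer.from_pretrained("roberta-base")
# ...while use_fast=False restores the 3.x-style pure-Python tokenizer.
slow = AutoTokenizer.from_pretrained("roberta-base", use_fast=False)
assert fast.is_fast and not slow.is_fast
```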