Transformers v4.0.0-rc-1: Fast tokenizers, model outputs, file reorganization
Breaking changes since v3.x
Version v4.0.0 introduces several breaking changes that were necessary.
1. AutoTokenizers and pipelines now use fast (rust) tokenizers by default.
The python and rust tokenizers have roughly the same API, but the rust tokenizers have a more complete feature set. The main breaking change is the handling of overflowing tokens between the python and rust tokenizers.
How to obtain the same behavior as v3.x in v4.x
- The pipelines now contain additional features out of the box. See the token-classification pipeline with the grouped_entities flag (a short sketch follows the v3.x/v4.x example below).
- The auto-tokenizers now return rust tokenizers. In order to obtain the python tokenizers instead, the user may use the use_fast flag by setting it to False:
In version v3.x:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("xxx")
to obtain the same in version v4.x:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("xxx", use_fast=False)
2. SentencePiece is removed from the required dependencies
The requirement on the SentencePiece dependency has been lifted from the setup.py. This is done so that we may have a channel on anaconda cloud without relying on conda-forge. This means that the tokenizers that depend on the SentencePiece library will not be available with a standard transformers installation.
This includes the slow versions of:
XLNetTokenizer
AlbertTokenizer
CamembertTokenizer
MBartTokenizer
PegasusTokenizer
T5Tokenizer
ReformerTokenizer
XLMRobertaTokenizer
How to obtain the same behavior as v3.x in v4.x
In order to obtain the same behavior as version v3.x, you should additionally install sentencepiece:
In version v3.x:
pip install transformers
to obtain the same in version v4.x:
pip install transformers[sentencepiece]
or
pip install transformers sentencepiece
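Once sentencepiece is installed, the slow tokenizers listed above load as before. A minimal sketch, with an illustrative checkpoint name:

```python
# Sketch only: XLNetTokenizer is one of the slow tokenizers that needs the
# sentencepiece package.
from transformers import XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
print(tokenizer.tokenize("SentencePiece is installed"))
```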
3. The architecture of the repo has been updated so that each model resides in its folder
The past and foreseeable addition of new models means that the number of files in the directory src/transformers keeps growing and becomes harder to navigate and understand. We made the choice to put each model and the files accompanying it in their own sub-directories.
This is a breaking change, as importing intermediate layers directly from a model's module needs to be done via a different path.
How to obtain the same behavior as v3.x in v4.x
In order to obtain the same behavior as version v3.x, you should update the path used to access the layers.
In version v3.x:
from transformers.modeling_bert import BertLayer
to obtain the same in version v4.x:
from transformers.models.bert.modeling_bert import BertLayer
4. Switching the return_dict argument to True by default
The return_dict argument enables the return of named-tuple-like python objects containing the model outputs, instead of the standard tuples. This object is self-documented: keys can be used to retrieve values, while it also behaves as a tuple, so users may retrieve objects by index or by slice.
This is a breaking change, as the new output object, unlike a plain tuple, cannot be unpacked: value0, value1 = outputs will not work.
How to obtain the same behavior as v3.x in v4.x
In order to obtain the same behavior as version v3.x, you should set the return_dict argument to False, either in the model configuration or during the forward pass.
In version v3.x:
outputs = model(**inputs)
to obtain the same in version v4.x:
outputs = model(**inputs, return_dict=False)
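With the new default, outputs can be read by attribute or by index. A minimal sketch, with illustrative model and inputs:

```python
# Sketch only: with return_dict=True (the v4.x default), the output object
# exposes named fields while still supporting indexing.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs)   # return_dict=True by default
logits = outputs.logits     # access by attribute
same_logits = outputs[0]    # indexing still works
```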
5. Removed some deprecated attributes
Attributes that were deprecated have been removed if they had been deprecated for at least a month. The full list of deprecated attributes can be found in #8604.
Here is a list of these attributes/methods/arguments and what their replacements should be:
In several models, the labels become consistent with the other models:
- masked_lm_labels becomes labels in AlbertForMaskedLM and AlbertForPreTraining.
- masked_lm_labels becomes labels in BertForMaskedLM and BertForPreTraining.
- masked_lm_labels becomes labels in DistilBertForMaskedLM.
- masked_lm_labels becomes labels in ElectraForMaskedLM.
- masked_lm_labels becomes labels in LongformerForMaskedLM.
- masked_lm_labels becomes labels in MobileBertForMaskedLM.
- masked_lm_labels becomes labels in RobertaForMaskedLM.
- lm_labels becomes labels in BartForConditionalGeneration.
- lm_labels becomes labels in GPT2DoubleHeadsModel.
- lm_labels becomes labels in OpenAIGPTDoubleHeadsModel.
- lm_labels becomes labels in T5ForConditionalGeneration.
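A minimal sketch of the rename, with an illustrative checkpoint and the v3.x spelling shown as a comment:

```python
# Sketch only: masked_lm_labels is now labels in the masked-LM models above.
from transformers import BertForMaskedLM, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")

inputs = tokenizer("Paris is the capital of France.", return_tensors="pt")
# v3.x: model(**inputs, masked_lm_labels=inputs["input_ids"])
outputs = model(**inputs, labels=inputs["input_ids"])
print(outputs.loss)
```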
In several models, the caching mechanism becomes consistent with the other models:
- decoder_cached_states becomes past_key_values in all BART-like, FSMT and T5 models.
- decoder_past_key_values becomes past_key_values in all BART-like, FSMT and T5 models.
- past becomes past_key_values in all CTRL models.
- past becomes past_key_values in all GPT-2 models.
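A minimal sketch of the unified cache name for GPT-2, with the v3.x spelling shown in the comments:

```python
# Sketch only: the cache returned and consumed by GPT-2 is now called
# past_key_values (formerly past).
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello", return_tensors="pt")
outputs = model(**inputs, use_cache=True)
cache = outputs.past_key_values                     # formerly outputs.past
next_ids = torch.tensor([[tokenizer.eos_token_id]])
outputs = model(input_ids=next_ids, past_key_values=cache)  # formerly past=cache
```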
Regarding the tokenizer classes:
- The tokenizer attribute max_len becomes model_max_length.
- The tokenizer attribute return_lengths becomes return_length.
- The tokenizer encoding argument is_pretokenized becomes is_split_into_words.
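A minimal sketch of the renamed encoding argument and attribute, with an illustrative checkpoint:

```python
# Sketch only: is_pretokenized is now is_split_into_words, and max_len is
# now model_max_length.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# v3.x: tokenizer(["Hello", "world"], is_pretokenized=True)
encoding = tokenizer(["Hello", "world"], is_split_into_words=True)
print(tokenizer.model_max_length)  # formerly tokenizer.max_len
```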
Regarding the Trainer class:
- The Trainer argument tb_writer is removed in favor of the callback TensorBoardCallback(tb_writer=...).
- The Trainer argument prediction_loss_only is removed in favor of the class argument args.prediction_loss_only.
- The Trainer attribute data_collator should be a callable.
- The Trainer method _log is deprecated in favor of log.
- The Trainer method _training_step is deprecated in favor of training_step.
- The Trainer method _prediction_loop is deprecated in favor of prediction_loop.
- The Trainer method is_local_master is deprecated in favor of is_local_process_zero.
- The Trainer method is_world_master is deprecated in favor of is_world_process_zero.
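A minimal sketch of the callback that replaces the tb_writer argument; the model is only a stub and no training is run:

```python
# Sketch only: TensorBoard logging is configured through a callback rather
# than the removed tb_writer argument.
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from transformers.integrations import TensorBoardCallback

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
args = TrainingArguments(output_dir="out", prediction_loss_only=True)

# v3.x: Trainer(model=model, args=args, tb_writer=writer)
trainer = Trainer(model=model, args=args, callbacks=[TensorBoardCallback()])
```

When tensorboard is available, Trainer adds this callback by itself; passing it explicitly is mainly useful to supply a custom tb_writer.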
Regarding the TFTrainer class:
- The TFTrainer argument prediction_loss_only is removed in favor of the class argument args.prediction_loss_only.
- The TFTrainer method _log is deprecated in favor of log.
- The TFTrainer method _prediction_loop is deprecated in favor of prediction_loop.
- The TFTrainer method _setup_wandb is deprecated in favor of setup_wandb.
- The TFTrainer method _run_model is deprecated in favor of run_model.
Regarding the TrainingArguments and TFTrainingArguments classes:
- The TrainingArguments argument evaluate_during_training is deprecated in favor of evaluation_strategy.
- The TFTrainingArguments argument evaluate_during_training is deprecated in favor of evaluation_strategy.
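A minimal sketch of the renamed argument; "steps" is one of the accepted strategies, alongside "no" and "epoch":

```python
# Sketch only: evaluate_during_training is replaced by evaluation_strategy.
from transformers import TrainingArguments

# v3.x: TrainingArguments(output_dir="out", evaluate_during_training=True)
args = TrainingArguments(output_dir="out", evaluation_strategy="steps")
```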
Regarding the Transfo-XL model:
- The Transfo-XL configuration attribute tie_weight becomes tie_words_embeddings.
- The Transfo-XL modeling method reset_length becomes reset_memory_length.
Regarding pipelines:
- The FillMaskPipeline argument topk becomes top_k.
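A minimal sketch of the renamed pipeline argument, here passed through the pipeline factory with an illustrative checkpoint:

```python
# Sketch only: the number of fill-mask predictions is controlled by top_k
# (formerly topk).
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="bert-base-uncased", top_k=3)
print(fill_mask("Paris is the [MASK] of France."))
```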
Model Templates
Version 4.0.0 will be the first to include the experimental feature of model templates. These model templates aim to facilitate the addition of new models to the library by doing most of the work: generating the model/configuration/tokenization/test files that fit the API, with respect to the choices the user has made in terms of naming and functionality.
This release includes a model template for the encoder model (similar to the BERT architecture). Generating a model using the template will generate the files, put them at the appropriate location, reference them throughout the code-base, and generate a working test suite. The user should then only modify the files to their liking, rather than creating the model from scratch.
Feedback welcome, get started from the README here.
- Model templates encoder only #8509 (@LysandreJik)
New model additions
mT5 and T5 version 1.1 (@patrickvonplaten)
T5 version 1.1 is an improved version of the original T5 model; see here: https://github.com/google-research/text-to-text-transfer-transformer/blob/master/released_checkpoints.md
The multilingual T5 model (mT5) was presented in https://arxiv.org/abs/2010.11934 and is based on the T5v1.1 architecture.
Multiple pre-trained checkpoints have been added to the library.
Relevant pull requests:
- T5 & mT5 #8552 (@patrickvonplaten)
- [MT5] More docs #8589 (@patrickvonplaten)
- Fix init for MT5 #8591 (@sgugger)
TF DPR
The DPR model has been added in TensorFlow by @ratthachat to match its PyTorch counterpart.
- Add TFDPR #8203 (@ratthachat)
TF Longformer
Additional heads have been added to the TensorFlow Longformer implementation: SequenceClassification, MultipleChoice and TokenClassification.
- Tf longformer for sequence classification #8231 (@elk-cloner)
Bug fixes and improvements
- [s2s/distill] hparams.tokenizer_name = hparams.teacher #8382 (@ShichaoSun)
- [examples] better PL version check #8429 (@stas00)
- Question template #8440 (@sgugger)
- [docs] improve bart/marian/mBART/pegasus docs #8421 (@sshleifer)
- Add auto next sentence prediction #8432 (@jplu)
- Windows dev section in the contributing file #8436 (@jplu)
- [testing utils] get_auto_remove_tmp_dir more intuitive behavior #8401 (@stas00)
- Add missing import #8444 (@jplu)
- [T5 Tokenizer] Fix t5 special tokens #8435 (@patrickvonplaten)
- using multi_gpu consistently #8446 (@stas00)
- Add missing tasks to pipeline docstring #8428 (@bryant1410)
- [No merge] TF integration testing #7621 (@LysandreJik)
- [T5Tokenizer] fix t5 token type ids #8437 (@patrickvonplaten)
- Bug fix for apply_chunking_to_forward chunking dimension check #8391 (@pedrocolon93)
- Fix TF Longformer #8460 (@jplu)
- Add next sentence prediction loss computation #8462 (@jplu)
- Fix TF next sentence output #8466 (@jplu)
- Example NER script predicts on tokenized dataset #8468 (@sarnoult)
- Replaced unnecessary iadd operations on lists in tokenization_utils.py with proper list methods #8433 (@bombs-kim)
- Flax/Jax documentation #8331 (@mfuntowicz)
- [s2s] distill t5-large -> t5-small #8376 (@sbhaktha)
- Update deploy-docs dependencies on CI to enable Flax #8475 (@mfuntowicz)
- Fix on "examples/language-modeling" to support more datasets #8474 (@zeyuyun1)
- Fix doc bug #8500 (@mymusise)
- Model sharing doc #8498 (@sgugger)
- Fix SqueezeBERT for masked language model #8479 (@forresti)
- Fix logging in the examples #8458 (@jplu)
- Fix check scripts for Windows #8491 (@jplu)
- Add pretraining loss computation for TF Bert pretraining #8470 (@jplu)
- [T5] Bug correction & Refactor #8518 (@patrickvonplaten)
- Model sharing doc: more tweaks #8520 (@julien-c)
- [T5] Fix load weights function #8528 (@patrickvonplaten)
- Rework some TF tests #8492 (@jplu)
- [breaking|pipelines|tokenizers] Adding slow-fast tokenizers equivalence tests pipelines - Removing sentencepiece as a required dependency #8073 (@thomwolf)
- Adding the prepare_seq2seq_batch function to ProphetNet #8515 (@forest1988)
- Update version to v4.0.0-dev #8568 (@sgugger)
- TAPAS tokenizer & tokenizer tests #8482 (@LysandreJik)
- Switch return_dict to True by default #8530 (@sgugger)
- Fix mixed precision issue for GPT2 #8572 (@jplu)
- Reorganize repo #8580 (@sgugger)
- Tokenizers: ability to load from model subfolder #8586 (@julien-c)
- Fix model templates #8595 (@sgugger)
- [examples tests] tests that are fine on multi-gpu #8582 (@stas00)
- Fix check repo utils #8600 (@sgugger)
- Tokenizers should be framework agnostic #8599 (@LysandreJik)
- Remove deprecated #8604 (@sgugger)
- Fixed link to the wrong paper. #8607 (@cronoik)
- Reset loss to zero on logging in Trainer to avoid bfloat16 issues #8561 (@bminixhofer)
- Fix DataCollatorForLanguageModeling #8621 (@sgugger)
- [s2s] multigpu skip #8613 (@stas00)
- [s2s] fix finetune.py to adjust for #8530 changes #8612 (@stas00)
- tf_bart typo - self.self.activation_dropout #8611 (@ratthachat)
- New TF loading weights #8490 (@jplu)
- Adding PrefixConstrainedLogitsProcessor #8529 (@nicola-decao)
- [Tokenizer Doc] Improve tokenizer summary #8622 (@patrickvonplaten)
- Fixes the training resuming with gradient accumulation #8624 (@sgugger)
- Fix training from scratch in new scripts #8623 (@sgugger)
- [s2s] distillation apex breaks return_dict obj #8631 (@stas00)
- Updated the Extractive Question Answering code snippets #8636 (@cronoik)
- Fix missing return_dict in RAG example to use a custom knowledge source #8653 (@lhoestq)
- Fix a bunch of slow tests #8634 (@LysandreJik)
- Better filtering of the model outputs in Trainer #8633 (@sgugger)
- Add sentencepiece to the CI and fix tests #8672 (@sgugger)
- Document adam betas TrainingArguments #8688 (@sgugger)
- Change default cache path #8734 (@sgugger)
- consistent ignore keys + make private #8737 (@stas00)
- Update TF BERT test & TF BERT test update (@LysandreJik)
- Fix slow tests v2 #8746 (@LysandreJik)
- MT5 should have an autotokenizer #8743 (@LysandreJik)
- Fix QA argument handler #8765 (@LysandreJik & @Narsil)
- Fix dpr<>bart config for RAG #8808 (@patrickvonplaten)
- [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes #8791 (@KristianHolsheimer)
- [Flax Test] Add require pytorch to flix flax test #8816 (@patrickvonplaten)
- Big model table #8774 (@sgugger)
- fix mt5 config #8832 (@patrickvonplaten)
- Migration guide from v3.x to v4.x #8763 (@LysandreJik)
- add xlnet mems and fix merge conflicts (@patrickvonplaten)
- Add a direct link to the big table #8850 (@sgugger)
- Remove deprecated evaluate_during_training #8852 (@sgugger)