diff --git a/README.md b/README.md index 746e789c8fbe..c0d5fd032b31 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[BERT For Sequence Generation](https://huggingface.co/transformers/model_doc/bertgeneration.html)** (from Google) released with the paper [Leveraging Pre-trained Checkpoints for Sequence Generation Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. 1. **[Blenderbot](https://huggingface.co/transformers/model_doc/blenderbot.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. 1. **[BlenderbotSmall](https://huggingface.co/transformers/model_doc/blenderbot_small.html)** (from Facebook) released with the paper [Recipes for building an open-domain chatbot](https://arxiv.org/abs/2004.13637) by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. +1. **[BORT](https://huggingface.co/transformers/model_doc/bort.html)** (from Alexa) released with the paper [Optimal Subarchitecture Extraction For BERT](https://arxiv.org/abs/2010.10499) by Adrian de Wynter and Daniel J. Perry. 1. **[CamemBERT](https://huggingface.co/transformers/model_doc/camembert.html)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. 1. **[ConvBERT](https://huggingface.co/transformers/model_doc/convbert.html)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan. 1. **[CTRL](https://huggingface.co/transformers/model_doc/ctrl.html)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. diff --git a/docs/source/index.rst b/docs/source/index.rst index 3ac7e8c2afd2..38a66e12e55e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -103,103 +103,105 @@ and conversion utilities for the following models: 7. :doc:`BlenderbotSmall ` (from Facebook) released with the paper `Recipes for building an open-domain chatbot `__ by Stephen Roller, Emily Dinan, Naman Goyal, Da Ju, Mary Williamson, Yinhan Liu, Jing Xu, Myle Ott, Kurt Shuster, Eric M. Smith, Y-Lan Boureau, Jason Weston. -8. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty +8. :doc:`BORT ` (from Alexa) released with the paper `Optimal Subarchitecture Extraction For BERT + `__ by Adrian de Wynter and Daniel J. Perry. +9. :doc:`CamemBERT ` (from Inria/Facebook/Sorbonne) released with the paper `CamemBERT: a Tasty French Language Model `__ by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. -9. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with - Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, - Yunpeng Chen, Jiashi Feng, Shuicheng Yan. -10. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language +10. :doc:`ConvBERT ` (from YituTech) released with the paper `ConvBERT: Improving BERT with + Span-based Dynamic Convolution `__ by Zihang Jiang, Weihao Yu, Daquan Zhou, + Yunpeng Chen, Jiashi Feng, Shuicheng Yan. +11. :doc:`CTRL ` (from Salesforce) released with the paper `CTRL: A Conditional Transformer Language Model for Controllable Generation `__ by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. -11. :doc:`DeBERTa ` (from Microsoft Research) released with the paper `DeBERTa: Decoding-enhanced +12. :doc:`DeBERTa ` (from Microsoft Research) released with the paper `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. -12. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale +13. :doc:`DialoGPT ` (from Microsoft Research) released with the paper `DialoGPT: Large-Scale Generative Pre-training for Conversational Response Generation `__ by Yizhe Zhang, Siqi Sun, Michel Galley, Yen-Chun Chen, Chris Brockett, Xiang Gao, Jianfeng Gao, Jingjing Liu, Bill Dolan. -13. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a +14. :doc:`DistilBERT ` (from HuggingFace), released together with the paper `DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter `__ by Victor Sanh, Lysandre Debut and Thomas Wolf. The same method has been applied to compress GPT2 into `DistilGPT2 `__, RoBERTa into `DistilRoBERTa `__, Multilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -14. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain +15. :doc:`DPR ` (from Facebook) released with the paper `Dense Passage Retrieval for Open-Domain Question Answering `__ by Vladimir Karpukhin, Barlas Oğuz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. -15. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: +16. :doc:`ELECTRA ` (from Google Research/Stanford University) released with the paper `ELECTRA: Pre-training text encoders as discriminators rather than generators `__ by Kevin Clark, Minh-Thang Luong, Quoc V. Le, Christopher D. Manning. -16. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model +17. :doc:`FlauBERT ` (from CNRS) released with the paper `FlauBERT: Unsupervised Language Model Pre-training for French `__ by Hang Le, Loïc Vial, Jibril Frej, Vincent Segonne, Maximin Coavoux, Benjamin Lecouteux, Alexandre Allauzen, Benoît Crabbé, Laurent Besacier, Didier Schwab. -17. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: +18. :doc:`Funnel Transformer ` (from CMU/Google Brain) released with the paper `Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing `__ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -18. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative +19. :doc:`GPT ` (from OpenAI) released with the paper `Improving Language Understanding by Generative Pre-Training `__ by Alec Radford, Karthik Narasimhan, Tim Salimans and Ilya Sutskever. -19. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask +20. :doc:`GPT-2 ` (from OpenAI) released with the paper `Language Models are Unsupervised Multitask Learners `__ by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**. -20. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training +21. :doc:`LayoutLM ` (from Microsoft Research Asia) released with the paper `LayoutLM: Pre-training of Text and Layout for Document Image Understanding `__ by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei, Ming Zhou. -21. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer +22. :doc:`LED ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -22. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document +23. :doc:`Longformer ` (from AllenAI) released with the paper `Longformer: The Long-Document Transformer `__ by Iz Beltagy, Matthew E. Peters, Arman Cohan. -23. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality +24. :doc:`LXMERT ` (from UNC Chapel Hill) released with the paper `LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering `__ by Hao Tan and Mohit Bansal. -24. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by +25. :doc:`MarianMT ` Machine translation models trained using `OPUS `__ data by Jörg Tiedemann. The `Marian Framework `__ is being developed by the Microsoft Translator Team. -25. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for +26. :doc:`MBart ` (from Facebook) released with the paper `Multilingual Denoising Pre-training for Neural Machine Translation `__ by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. -26. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted +27. :doc:`MPNet ` (from Microsoft Research) released with the paper `MPNet: Masked and Permuted Pre-training for Language Understanding `__ by Kaitao Song, Xu Tan, Tao Qin, Jianfeng Lu, Tie-Yan Liu. -27. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained +28. :doc:`MT5 ` (from Google AI) released with the paper `mT5: A massively multilingual pre-trained text-to-text transformer `__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel. -28. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted +29. :doc:`Pegasus ` (from Google) released with the paper `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive Summarization `__> by Jingqing Zhang, Yao Zhao, Mohammad Saleh and Peter J. Liu. -29. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting +30. :doc:`ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -30. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient +31. :doc:`Reformer ` (from Google Research) released with the paper `Reformer: The Efficient Transformer `__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. -31. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT +32. :doc:`RoBERTa ` (from Facebook), released together with the paper a `Robustly Optimized BERT Pretraining Approach `__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov. ultilingual BERT into `DistilmBERT `__ and a German version of DistilBERT. -32. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP +33. :doc:`SqueezeBert ` released with the paper `SqueezeBERT: What can computer vision teach NLP about efficient neural networks? `__ by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer. -33. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a +34. :doc:`T5 ` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer `__ by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu. -34. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via +35. :doc:`TAPAS ` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via Pre-training `__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos. -35. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: +36. :doc:`Transformer-XL ` (from Google/CMU) released with the paper `Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context `__ by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov. -36. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model +37. :doc:`XLM ` (from Facebook) released together with the paper `Cross-lingual Language Model Pretraining `__ by Guillaume Lample and Alexis Conneau. -37. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: +38. :doc:`XLM-ProphetNet ` (from Microsoft Research) released with the paper `ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training `__ by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou. -38. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised +39. :doc:`XLM-RoBERTa ` (from Facebook AI), released together with the paper `Unsupervised Cross-lingual Representation Learning at Scale `__ by Alexis Conneau*, Kartikay Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke Zettlemoyer and Veselin Stoyanov. -39. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive +40. :doc:`XLNet ` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive Pretraining for Language Understanding `__ by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le. @@ -375,6 +377,7 @@ TensorFlow and/or Flax. model_doc/bertgeneration model_doc/blenderbot model_doc/blenderbot_small + model_doc/bort model_doc/camembert model_doc/convbert model_doc/ctrl diff --git a/docs/source/model_doc/bort.rst b/docs/source/model_doc/bort.rst new file mode 100644 index 000000000000..14b5df79c1fb --- /dev/null +++ b/docs/source/model_doc/bort.rst @@ -0,0 +1,46 @@ +.. + Copyright 2020 The HuggingFace Team. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + specific language governing permissions and limitations under the License. + +BORT +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The BORT model was proposed in `Optimal Subarchitecture Extraction for BERT `__ by +Adrian de Wynter and Daniel J. Perry. It is an optimal subset of architectural parameters for the BERT, which the +authors refer to as "Bort". + +The abstract from the paper is the following: + +*We extract an optimal subset of architectural parameters for the BERT architecture from Devlin et al. (2018) by +applying recent breakthroughs in algorithms for neural architecture search. This optimal subset, which we refer to as +"Bort", is demonstrably smaller, having an effective (that is, not counting the embedding layer) size of 5.5% the +original BERT-large architecture, and 16% of the net size. Bort is also able to be pretrained in 288 GPU hours, which +is 1.2% of the time required to pretrain the highest-performing BERT parametric architectural variant, RoBERTa-large +(Liu et al., 2019), and about 33% of that of the world-record, in GPU hours, required to train BERT-large on the same +hardware. It is also 7.9x faster on a CPU, as well as being better performing than other compressed variants of the +architecture, and some of the non-compressed variants: it obtains performance improvements of between 0.3% and 31%, +absolute, with respect to BERT-large, on multiple public natural language understanding (NLU) benchmarks.* + +Tips: + +- BORT's model architecture is based on BERT, so one can refer to :doc:`BERT's documentation page ` for the + model's API as well as usage examples. +- BORT uses the RoBERTa tokenizer instead of the BERT tokenizer, so one can refer to :doc:`RoBERTa's documentation page + ` for the tokenizer's API as well as usage examples. +- BORT requires a specific fine-tuning algorithm, called `Agora + `__ , + that is sadly not open-sourced yet. It would be very useful for the community, if someone tries to implement the + algorithm to make BORT fine-tuning work. + +The original code can be found `here `__. diff --git a/docs/source/model_doc/dialogpt.rst b/docs/source/model_doc/dialogpt.rst index f821292d9406..a7a09b370465 100644 --- a/docs/source/model_doc/dialogpt.rst +++ b/docs/source/model_doc/dialogpt.rst @@ -48,7 +48,6 @@ modeling. We first concatenate all dialog turns within a dialogue session into a sequence length), ended by the end-of-text token.* For more information please confer to the original paper. -DialoGPT's architecture is based on the GPT2 model, so one can refer to GPT2's `docstring -`_. +DialoGPT's architecture is based on the GPT2 model, so one can refer to :doc:`GPT2's documentation page `. The original code can be found `here `_. diff --git a/src/transformers/models/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py b/src/transformers/models/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py new file mode 100644 index 000000000000..acc6981d2bee --- /dev/null +++ b/src/transformers/models/bort/convert_bort_original_gluonnlp_checkpoint_to_pytorch.py @@ -0,0 +1,318 @@ +# coding=utf-8 +# Copyright 2020, The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Bort checkpoint.""" + + +import argparse +import os + +import numpy as np +import torch +from packaging import version + +import gluonnlp as nlp +import mxnet as mx +from gluonnlp.base import get_home_dir +from gluonnlp.model.bert import BERTEncoder +from gluonnlp.model.utils import _load_vocab +from gluonnlp.vocab import Vocab +from transformers import BertConfig, BertForMaskedLM, BertModel, RobertaTokenizer +from transformers.models.bert.modeling_bert import ( + BertIntermediate, + BertLayer, + BertOutput, + BertSelfAttention, + BertSelfOutput, +) +from transformers.utils import logging + + +if version.parse(nlp.__version__) != version.parse("0.8.3"): + raise Exception("requires gluonnlp == 0.8.3") + +if version.parse(mx.__version__) != version.parse("1.5.0"): + raise Exception("requires mxnet == 1.5.0") + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +SAMPLE_TEXT = "The Nymphenburg Palace is a beautiful palace in Munich!" + + +def convert_bort_checkpoint_to_pytorch(bort_checkpoint_path: str, pytorch_dump_folder_path: str): + """ + Convert the original Bort checkpoint (based on MXNET and Gluonnlp) to our BERT structure- + """ + + # Original Bort configuration + bort_4_8_768_1024_hparams = { + "attention_cell": "multi_head", + "num_layers": 4, + "units": 1024, + "hidden_size": 768, + "max_length": 512, + "num_heads": 8, + "scaled": True, + "dropout": 0.1, + "use_residual": True, + "embed_size": 1024, + "embed_dropout": 0.1, + "word_embed": None, + "layer_norm_eps": 1e-5, + "token_type_vocab_size": 2, + } + + predefined_args = bort_4_8_768_1024_hparams + + # Let's construct the original Bort model here + # Taken from official BERT implementation, see: + # https://github.com/alexa/bort/blob/master/bort/bort.py + encoder = BERTEncoder( + attention_cell=predefined_args["attention_cell"], + num_layers=predefined_args["num_layers"], + units=predefined_args["units"], + hidden_size=predefined_args["hidden_size"], + max_length=predefined_args["max_length"], + num_heads=predefined_args["num_heads"], + scaled=predefined_args["scaled"], + dropout=predefined_args["dropout"], + output_attention=False, + output_all_encodings=False, + use_residual=predefined_args["use_residual"], + activation=predefined_args.get("activation", "gelu"), + layer_norm_eps=predefined_args.get("layer_norm_eps", None), + ) + + # Vocab information needs to be fetched first + # It's the same as RoBERTa, so RobertaTokenizer can be used later + vocab_name = "openwebtext_ccnews_stories_books_cased" + + # Specify download folder to Gluonnlp's vocab + gluon_cache_dir = os.path.join(get_home_dir(), "models") + bort_vocab = _load_vocab(vocab_name, None, gluon_cache_dir, cls=Vocab) + + original_bort = nlp.model.BERTModel( + encoder, + len(bort_vocab), + units=predefined_args["units"], + embed_size=predefined_args["embed_size"], + embed_dropout=predefined_args["embed_dropout"], + word_embed=predefined_args["word_embed"], + use_pooler=False, + use_token_type_embed=False, + token_type_vocab_size=predefined_args["token_type_vocab_size"], + use_classifier=False, + use_decoder=False, + ) + + original_bort.load_parameters(bort_checkpoint_path, cast_dtype=True, ignore_extra=True) + params = original_bort._collect_params_with_prefix() + + # Build our config 🤗 + hf_bort_config_json = { + "architectures": ["BertForMaskedLM"], + "attention_probs_dropout_prob": predefined_args["dropout"], + "hidden_act": "gelu", + "hidden_dropout_prob": predefined_args["dropout"], + "hidden_size": predefined_args["embed_size"], + "initializer_range": 0.02, + "intermediate_size": predefined_args["hidden_size"], + "layer_norm_eps": predefined_args["layer_norm_eps"], + "max_position_embeddings": predefined_args["max_length"], + "model_type": "bort", + "num_attention_heads": predefined_args["num_heads"], + "num_hidden_layers": predefined_args["num_layers"], + "pad_token_id": 1, # 2 = BERT, 1 = RoBERTa + "type_vocab_size": 1, # 2 = BERT, 1 = RoBERTa + "vocab_size": len(bort_vocab), + } + + hf_bort_config = BertConfig.from_dict(hf_bort_config_json) + hf_bort_model = BertForMaskedLM(hf_bort_config) + hf_bort_model.eval() + + # Parameter mapping table (Gluonnlp to Transformers) + # * denotes layer index + # + # | Gluon Parameter | Transformers Parameter + # | -------------------------------------------------------------- | ---------------------- + # | `encoder.layer_norm.beta` | `bert.embeddings.LayerNorm.bias` + # | `encoder.layer_norm.gamma` | `bert.embeddings.LayerNorm.weight` + # | `encoder.position_weight` | `bert.embeddings.position_embeddings.weight` + # | `word_embed.0.weight` | `bert.embeddings.word_embeddings.weight` + # | `encoder.transformer_cells.*.attention_cell.proj_key.bias` | `bert.encoder.layer.*.attention.self.key.bias` + # | `encoder.transformer_cells.*.attention_cell.proj_key.weight` | `bert.encoder.layer.*.attention.self.key.weight` + # | `encoder.transformer_cells.*.attention_cell.proj_query.bias` | `bert.encoder.layer.*.attention.self.query.bias` + # | `encoder.transformer_cells.*.attention_cell.proj_query.weight` | `bert.encoder.layer.*.attention.self.query.weight` + # | `encoder.transformer_cells.*.attention_cell.proj_value.bias` | `bert.encoder.layer.*.attention.self.value.bias` + # | `encoder.transformer_cells.*.attention_cell.proj_value.weight` | `bert.encoder.layer.*.attention.self.value.weight` + # | `encoder.transformer_cells.*.ffn.ffn_2.bias` | `bert.encoder.layer.*.attention.output.dense.bias` + # | `encoder.transformer_cells.*.ffn.ffn_2.weight` | `bert.encoder.layer.*.attention.output.dense.weight` + # | `encoder.transformer_cells.*.layer_norm.beta` | `bert.encoder.layer.*.attention.output.LayerNorm.bias` + # | `encoder.transformer_cells.*.layer_norm.gamma` | `bert.encoder.layer.*.attention.output.LayerNorm.weight` + # | `encoder.transformer_cells.*.ffn.ffn_1.bias` | `bert.encoder.layer.*.intermediate.dense.bias` + # | `encoder.transformer_cells.*.ffn.ffn_1.weight` | `bert.encoder.layer.*.intermediate.dense.weight` + # | `encoder.transformer_cells.*.ffn.layer_norm.beta` | `bert.encoder.layer.*.output.LayerNorm.bias` + # | `encoder.transformer_cells.*.ffn.layer_norm.gamma` | `bert.encoder.layer.*.output.LayerNorm.weight` + # | `encoder.transformer_cells.*.proj.bias` | `bert.encoder.layer.*.output.dense.bias` + # | `encoder.transformer_cells.*.proj.weight` | `bert.encoder.layer.*.output.dense.weight` + + # Helper function to convert MXNET Arrays to PyTorch + def to_torch(mx_array) -> torch.nn.Parameter: + return torch.nn.Parameter(torch.FloatTensor(mx_array.data().asnumpy())) + + # Check param shapes and map new HF param back + def check_and_map_params(hf_param, gluon_param): + shape_hf = hf_param.shape + + gluon_param = to_torch(params[gluon_param]) + shape_gluon = gluon_param.shape + + assert ( + shape_hf == shape_gluon + ), f"The gluon parameter {gluon_param} has shape {shape_gluon}, but expects shape {shape_hf} for Transformers" + + return gluon_param + + hf_bort_model.bert.embeddings.word_embeddings.weight = check_and_map_params( + hf_bort_model.bert.embeddings.word_embeddings.weight, "word_embed.0.weight" + ) + hf_bort_model.bert.embeddings.position_embeddings.weight = check_and_map_params( + hf_bort_model.bert.embeddings.position_embeddings.weight, "encoder.position_weight" + ) + hf_bort_model.bert.embeddings.LayerNorm.bias = check_and_map_params( + hf_bort_model.bert.embeddings.LayerNorm.bias, "encoder.layer_norm.beta" + ) + hf_bort_model.bert.embeddings.LayerNorm.weight = check_and_map_params( + hf_bort_model.bert.embeddings.LayerNorm.weight, "encoder.layer_norm.gamma" + ) + + # Inspired by RoBERTa conversion script, we just zero them out (Bort does not use them) + hf_bort_model.bert.embeddings.token_type_embeddings.weight.data = torch.zeros_like( + hf_bort_model.bert.embeddings.token_type_embeddings.weight.data + ) + + for i in range(hf_bort_config.num_hidden_layers): + layer: BertLayer = hf_bort_model.bert.encoder.layer[i] + + # self attention + self_attn: BertSelfAttention = layer.attention.self + + self_attn.key.bias.data = check_and_map_params( + self_attn.key.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_key.bias" + ) + + self_attn.key.weight.data = check_and_map_params( + self_attn.key.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_key.weight" + ) + self_attn.query.bias.data = check_and_map_params( + self_attn.query.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_query.bias" + ) + self_attn.query.weight.data = check_and_map_params( + self_attn.query.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_query.weight" + ) + self_attn.value.bias.data = check_and_map_params( + self_attn.value.bias.data, f"encoder.transformer_cells.{i}.attention_cell.proj_value.bias" + ) + self_attn.value.weight.data = check_and_map_params( + self_attn.value.weight.data, f"encoder.transformer_cells.{i}.attention_cell.proj_value.weight" + ) + + # self attention output + self_output: BertSelfOutput = layer.attention.output + + self_output.dense.bias = check_and_map_params( + self_output.dense.bias, f"encoder.transformer_cells.{i}.proj.bias" + ) + self_output.dense.weight = check_and_map_params( + self_output.dense.weight, f"encoder.transformer_cells.{i}.proj.weight" + ) + self_output.LayerNorm.bias = check_and_map_params( + self_output.LayerNorm.bias, f"encoder.transformer_cells.{i}.layer_norm.beta" + ) + self_output.LayerNorm.weight = check_and_map_params( + self_output.LayerNorm.weight, f"encoder.transformer_cells.{i}.layer_norm.gamma" + ) + + # intermediate + intermediate: BertIntermediate = layer.intermediate + + intermediate.dense.bias = check_and_map_params( + intermediate.dense.bias, f"encoder.transformer_cells.{i}.ffn.ffn_1.bias" + ) + intermediate.dense.weight = check_and_map_params( + intermediate.dense.weight, f"encoder.transformer_cells.{i}.ffn.ffn_1.weight" + ) + + # output + bert_output: BertOutput = layer.output + + bert_output.dense.bias = check_and_map_params( + bert_output.dense.bias, f"encoder.transformer_cells.{i}.ffn.ffn_2.bias" + ) + bert_output.dense.weight = check_and_map_params( + bert_output.dense.weight, f"encoder.transformer_cells.{i}.ffn.ffn_2.weight" + ) + bert_output.LayerNorm.bias = check_and_map_params( + bert_output.LayerNorm.bias, f"encoder.transformer_cells.{i}.ffn.layer_norm.beta" + ) + bert_output.LayerNorm.weight = check_and_map_params( + bert_output.LayerNorm.weight, f"encoder.transformer_cells.{i}.ffn.layer_norm.gamma" + ) + + # Save space and energy 🎄 + hf_bort_model.half() + + # Compare output of both models + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + input_ids = tokenizer.encode_plus(SAMPLE_TEXT)["input_ids"] + + # Get gluon output + gluon_input_ids = mx.nd.array([input_ids]) + output_gluon = original_bort(inputs=gluon_input_ids, token_types=[]) + + # Get Transformer output (save and reload model again) + hf_bort_model.save_pretrained(pytorch_dump_folder_path) + hf_bort_model = BertModel.from_pretrained(pytorch_dump_folder_path) + hf_bort_model.eval() + + input_ids = tokenizer.encode_plus(SAMPLE_TEXT, return_tensors="pt") + output_hf = hf_bort_model(**input_ids)[0] + + gluon_layer = output_gluon[0].asnumpy() + hf_layer = output_hf[0].detach().numpy() + + max_absolute_diff = np.max(np.abs(hf_layer - gluon_layer)).item() + success = np.allclose(gluon_layer, hf_layer, atol=1e-3) + + if success: + print("✔️ Both model do output the same tensors") + else: + print("❌ Both model do **NOT** output the same tensors") + print("Absolute difference is:", max_absolute_diff) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--bort_checkpoint_path", default=None, type=str, required=True, help="Path the official Bort params file." + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, required=True, help="Path to the output PyTorch model." + ) + args = parser.parse_args() + convert_bort_checkpoint_to_pytorch(args.bort_checkpoint_path, args.pytorch_dump_folder_path) diff --git a/tests/test_modeling_bort.py b/tests/test_modeling_bort.py new file mode 100644 index 000000000000..79ca94080107 --- /dev/null +++ b/tests/test_modeling_bort.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + + +if is_torch_available(): + import torch + + from transformers import AutoModel + + +@require_torch +@require_sentencepiece +@require_tokenizers +class BortIntegrationTest(unittest.TestCase): + @slow + def test_output_embeds_base_model(self): + model = AutoModel.from_pretrained("amazon/bort") + model.to(torch_device) + + input_ids = torch.tensor( + [[0, 18077, 4082, 7804, 8606, 6195, 2457, 3321, 11, 10489, 16, 269, 2579, 328, 2]], + device=torch_device, + dtype=torch.long, + ) # Schloß Nymphenburg in Munich is really nice! + output = model(input_ids)["last_hidden_state"] + expected_shape = torch.Size((1, 15, 1024)) + self.assertEqual(output.shape, expected_shape) + # compare the actual values for a slice. + expected_slice = torch.tensor( + [[[-0.0349, 0.0436, -1.8654], [-0.6964, 0.0835, -1.7393], [-0.9819, 0.2956, -0.2868]]], + device=torch_device, + dtype=torch.float, + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) diff --git a/tests/test_modeling_tf_bort.py b/tests/test_modeling_tf_bort.py new file mode 100644 index 000000000000..8053afbd30cf --- /dev/null +++ b/tests/test_modeling_tf_bort.py @@ -0,0 +1,51 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from transformers import is_tf_available +from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow + + +if is_tf_available(): + import numpy as np + import tensorflow as tf + + from transformers import TFAutoModel + + +@require_tf +@require_sentencepiece +@require_tokenizers +class TFBortIntegrationTest(unittest.TestCase): + @slow + def test_output_embeds_base_model(self): + model = TFAutoModel.from_pretrained("amazon/bort") + + input_ids = tf.convert_to_tensor( + [[0, 18077, 4082, 7804, 8606, 6195, 2457, 3321, 11, 10489, 16, 269, 2579, 328, 2]], + dtype=tf.int32, + ) # Schloß Nymphenburg in Munich is really nice! + + output = model(input_ids)["last_hidden_state"] + expected_shape = tf.TensorShape((1, 15, 1024)) + self.assertEqual(output.shape, expected_shape) + # compare the actual values for a slice. + expected_slice = tf.convert_to_tensor( + [[[-0.0349, 0.0436, -1.8654], [-0.6964, 0.0835, -1.7393], [-0.9819, 0.2956, -0.2868]]], + dtype=tf.float32, + ) + + self.assertTrue(np.allclose(output[:, :3, :3].numpy(), expected_slice.numpy(), atol=1e-4))