From c66fd23464133016cea59d189bebc8841e557c4f Mon Sep 17 00:00:00 2001
From: Abheesht Sharma
Date: Wed, 7 Jun 2023 17:49:26 +0530
Subject: [PATCH 1/2] Use backbone instead of s2s

---
 keras_nlp/models/bart/bart_presets.py          | 28 +++++++++++++++++++
 .../bart/bart_seq_2_seq_lm_preprocessor.py     |  8 ++++++
 2 files changed, 36 insertions(+)

diff --git a/keras_nlp/models/bart/bart_presets.py b/keras_nlp/models/bart/bart_presets.py
index 556a23fc35..a5c6f1230c 100644
--- a/keras_nlp/models/bart/bart_presets.py
+++ b/keras_nlp/models/bart/bart_presets.py
@@ -70,4 +70,32 @@
         "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
         "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
     },
+    "bart_large_en_cnn": {
+        "metadata": {
+            "description": (
+                "The `bart_large_en` backbone model fine-tuned on the CNN+DM "
+                "summarization dataset."
+            ),
+            "params": 406287360,
+            "official_name": "BART",
+            "path": "bart",
+            "model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
+        },
+        "config": {
+            "vocabulary_size": 50264,
+            "num_layers": 12,
+            "num_heads": 16,
+            "hidden_dim": 1024,
+            "intermediate_dim": 4096,
+            "dropout": 0.1,
+            "max_sequence_length": 1024,
+        },
+        "preprocessor_config": {},
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
+        "weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
+        "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
+        "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
+        "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
+    },
 }
diff --git a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
index 76bc6f8a76..5fae41ea61 100644
--- a/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
+++ b/keras_nlp/models/bart/bart_seq_2_seq_lm_preprocessor.py
@@ -14,15 +14,19 @@
 """BART Seq2Seq LM preprocessor layer."""
 
+import copy
+
 import tensorflow as tf
 from absl import logging
 
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor
+from keras_nlp.models.bart.bart_presets import backbone_presets
 from keras_nlp.utils.keras_utils import (
     convert_inputs_to_list_of_tensor_segments,
 )
 from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
+from keras_nlp.utils.python_utils import classproperty
 
 
 @keras_nlp_export("keras_nlp.models.BartSeq2SeqLMPreprocessor")
@@ -188,6 +192,10 @@ def call(self, x, y=None, sample_weight=None):
             sample_weight = decoder_padding_mask[..., 1:]
         return pack_x_y_sample_weight(x, y, sample_weight)
 
+    @classproperty
+    def presets(cls):
+        return copy.deepcopy(backbone_presets)
+
     def generate_preprocess(
         self,
         x,

From fc09ad2b98348a1adf02140110445707724303fd Mon Sep 17 00:00:00 2001
From: Abheesht Sharma
Date: Wed, 7 Jun 2023 18:13:46 +0530
Subject: [PATCH 2/2] Small changes

---
 keras_nlp/models/bart/bart_presets.py                   | 10 +++++-----
 .../checkpoint_conversion/convert_bart_checkpoints.py   |  1 +
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/keras_nlp/models/bart/bart_presets.py b/keras_nlp/models/bart/bart_presets.py
index a5c6f1230c..aa06254c10 100644
--- a/keras_nlp/models/bart/bart_presets.py
+++ b/keras_nlp/models/bart/bart_presets.py
@@ -91,11 +91,11 @@
             "max_sequence_length": 1024,
         },
         "preprocessor_config": {},
-        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/model.h5",
-        "weights_hash": "6bfe7e591af8c5699ce6f9f18753af9a",
-        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/vocab.json",
-        "vocabulary_hash": "cf410ee085c5c69c957bb1f6d8456596",
-        "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en/v1/merges.txt",
+        "weights_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/model.h5",
+        "weights_hash": "99782ecd9365956f016096fef9afd62c",
+        "vocabulary_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/vocab.json",
+        "vocabulary_hash": "be4d3c6f3f5495426b2c03b334334354",
+        "merges_url": "https://storage.googleapis.com/keras-nlp/models/bart_large_en_cnn/v1/merges.txt",
         "merges_hash": "75a37753dd7a28a2c5df80c28bf06e4e",
     },
 }
diff --git a/tools/checkpoint_conversion/convert_bart_checkpoints.py b/tools/checkpoint_conversion/convert_bart_checkpoints.py
index f5f800f286..1bbc464b6b 100644
--- a/tools/checkpoint_conversion/convert_bart_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_bart_checkpoints.py
@@ -26,6 +26,7 @@
 PRESET_MAP = {
     "bart_base_en": "facebook/bart-base",
     "bart_large_en": "facebook/bart-large",
+    "bart_large_en_cnn": "facebook/bart-large-cnn",
 }
 
 FLAGS = flags.FLAGS