Skip to content

Commit

Permalink
Simplify registering "built-in" presets (#1818)
Browse files Browse the repository at this point in the history
Instead of registering them with every class a preset should work with,
we just register them with the associated backbone. We can use that
to build `cls.presets` accessors for all library classes. E.g.

```python
keras_nlp.models.PaliGemmaTokenizer.presets
keras_nlp.models.GPT2Backbone.presets
keras_nlp.models.TextClassifier.presets
```
  • Loading branch information
mattdangerw authored Sep 10, 2024
1 parent 23815d6 commit 6d19bb3
Show file tree
Hide file tree
Showing 35 changed files with 82 additions and 111 deletions.
10 changes: 3 additions & 7 deletions keras_nlp/src/layers/preprocessing/audio_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
PreprocessingLayer,
)
from keras_nlp.src.utils.preset_utils import AUDIO_CONVERTER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import builtin_presets
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty

Expand Down Expand Up @@ -52,11 +51,8 @@ class AudioConverter(PreprocessingLayer):

@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets
"""List built-in presets for an `AudioConverter` subclass."""
return builtin_presets(cls)

@classmethod
def from_preset(
Expand Down
6 changes: 4 additions & 2 deletions keras_nlp/src/layers/preprocessing/audio_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@

class AudioConverterTest(TestCase):
def test_preset_accessors(self):
pali_gemma_presets = set(WhisperAudioConverter.presets.keys())
whisper_presets = set(WhisperAudioConverter.presets.keys())
all_presets = set(AudioConverter.presets.keys())
self.assertContainsSubset(pali_gemma_presets, all_presets)
self.assertContainsSubset(whisper_presets, all_presets)
self.assertIn("whisper_tiny_en", whisper_presets)
self.assertIn("whisper_tiny_en", all_presets)

@pytest.mark.large
def test_from_preset(self):
Expand Down
10 changes: 3 additions & 7 deletions keras_nlp/src/layers/preprocessing/image_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
PreprocessingLayer,
)
from keras_nlp.src.utils.preset_utils import IMAGE_CONVERTER_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import builtin_presets
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty

Expand Down Expand Up @@ -55,11 +54,8 @@ class ImageConverter(PreprocessingLayer):

@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets
"""List built-in presets for an `ImageConverter` subclass."""
return builtin_presets(cls)

@classmethod
def from_preset(
Expand Down
2 changes: 2 additions & 0 deletions keras_nlp/src/layers/preprocessing/image_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def test_preset_accessors(self):
pali_gemma_presets = set(PaliGemmaImageConverter.presets.keys())
all_presets = set(ImageConverter.presets.keys())
self.assertContainsSubset(pali_gemma_presets, all_presets)
self.assertIn("pali_gemma_3b_mix_224", pali_gemma_presets)
self.assertIn("pali_gemma_3b_mix_224", all_presets)

@pytest.mark.large
def test_from_preset(self):
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/src/models/albert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.albert.albert_backbone import AlbertBackbone
from keras_nlp.src.models.albert.albert_presets import backbone_presets
from keras_nlp.src.models.albert.albert_tokenizer import AlbertTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (AlbertBackbone, AlbertTokenizer))
register_presets(backbone_presets, AlbertBackbone)
10 changes: 3 additions & 7 deletions keras_nlp/src/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
from keras_nlp.src.utils.keras_utils import assert_quantization_support
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
from keras_nlp.src.utils.preset_utils import MODEL_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import builtin_presets
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_metadata
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty
Expand Down Expand Up @@ -141,11 +140,8 @@ def from_config(cls, config):

@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets
"""List built-in presets for a `Backbone` subclass."""
return builtin_presets(cls)

@classmethod
def from_preset(
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/src/models/backbone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def test_preset_accessors(self):
all_presets = set(Backbone.presets.keys())
self.assertContainsSubset(bert_presets, all_presets)
self.assertContainsSubset(gpt2_presets, all_presets)
self.assertIn("bert_tiny_en_uncased", bert_presets)
self.assertNotIn("bert_tiny_en_uncased", gpt2_presets)
self.assertIn("gpt2_base_en", gpt2_presets)
self.assertNotIn("gpt2_base_en", bert_presets)
self.assertIn("bert_tiny_en_uncased", all_presets)
self.assertIn("gpt2_base_en", all_presets)

@pytest.mark.large
def test_from_preset(self):
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/src/models/bart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.bart.bart_backbone import BartBackbone
from keras_nlp.src.models.bart.bart_presets import backbone_presets
from keras_nlp.src.models.bart.bart_tokenizer import BartTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (BartBackbone, BartTokenizer))
register_presets(backbone_presets, BartBackbone)
6 changes: 1 addition & 5 deletions keras_nlp/src/models/bert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@

from keras_nlp.src.models.bert.bert_backbone import BertBackbone
from keras_nlp.src.models.bert.bert_presets import backbone_presets
from keras_nlp.src.models.bert.bert_presets import classifier_presets
from keras_nlp.src.models.bert.bert_text_classifier import BertTextClassifier
from keras_nlp.src.models.bert.bert_tokenizer import BertTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (BertBackbone, BertTokenizer))
register_presets(classifier_presets, (BertTextClassifier, BertTokenizer))
register_presets(backbone_presets, BertBackbone)
5 changes: 1 addition & 4 deletions keras_nlp/src/models/bert/bert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@
},
"kaggle_handle": "kaggle://keras/bert/keras/bert_large_en/2",
},
}

classifier_presets = {
"bert_tiny_en_uncased_sst2": {
"metadata": {
"description": (
Expand All @@ -143,5 +140,5 @@
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/keras/bert_tiny_en_uncased_sst2/4",
}
},
}
3 changes: 1 addition & 2 deletions keras_nlp/src/models/bloom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.bloom.bloom_backbone import BloomBackbone
from keras_nlp.src.models.bloom.bloom_presets import backbone_presets
from keras_nlp.src.models.bloom.bloom_tokenizer import BloomTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (BloomBackbone, BloomTokenizer))
register_presets(backbone_presets, BloomBackbone)
5 changes: 1 addition & 4 deletions keras_nlp/src/models/deberta_v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
DebertaV3Backbone,
)
from keras_nlp.src.models.deberta_v3.deberta_v3_presets import backbone_presets
from keras_nlp.src.models.deberta_v3.deberta_v3_tokenizer import (
DebertaV3Tokenizer,
)
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (DebertaV3Backbone, DebertaV3Tokenizer))
register_presets(backbone_presets, DebertaV3Backbone)
5 changes: 1 addition & 4 deletions keras_nlp/src/models/distil_bert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
from keras_nlp.src.models.distil_bert.distil_bert_presets import (
backbone_presets,
)
from keras_nlp.src.models.distil_bert.distil_bert_tokenizer import (
DistilBertTokenizer,
)
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (DistilBertBackbone, DistilBertTokenizer))
register_presets(backbone_presets, DistilBertBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/electra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.electra.electra_backbone import ElectraBackbone
from keras_nlp.src.models.electra.electra_presets import backbone_presets
from keras_nlp.src.models.electra.electra_tokenizer import ElectraTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (ElectraBackbone, ElectraTokenizer))
register_presets(backbone_presets, ElectraBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/f_net/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.src.models.f_net.f_net_presets import backbone_presets
from keras_nlp.src.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (FNetBackbone, FNetTokenizer))
register_presets(backbone_presets, FNetBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/falcon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.falcon.falcon_backbone import FalconBackbone
from keras_nlp.src.models.falcon.falcon_presets import backbone_presets
from keras_nlp.src.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (FalconBackbone, FalconTokenizer))
register_presets(backbone_presets, FalconBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/gemma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_nlp.src.models.gemma.gemma_presets import backbone_presets
from keras_nlp.src.models.gemma.gemma_tokenizer import GemmaTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (GemmaBackbone, GemmaTokenizer))
register_presets(backbone_presets, GemmaBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/gpt2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.src.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.src.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (GPT2Backbone, GPT2Tokenizer))
register_presets(backbone_presets, GPT2Backbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/llama/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.llama.llama_backbone import LlamaBackbone
from keras_nlp.src.models.llama.llama_presets import backbone_presets
from keras_nlp.src.models.llama.llama_tokenizer import LlamaTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (LlamaBackbone, LlamaTokenizer))
register_presets(backbone_presets, LlamaBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/llama3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_nlp.src.models.llama3.llama3_presets import backbone_presets
from keras_nlp.src.models.llama3.llama3_tokenizer import Llama3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Llama3Backbone, Llama3Tokenizer))
register_presets(backbone_presets, Llama3Backbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/mistral/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.mistral.mistral_backbone import MistralBackbone
from keras_nlp.src.models.mistral.mistral_presets import backbone_presets
from keras_nlp.src.models.mistral.mistral_tokenizer import MistralTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (MistralBackbone, MistralTokenizer))
register_presets(backbone_presets, MistralBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/opt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.opt.opt_backbone import OPTBackbone
from keras_nlp.src.models.opt.opt_presets import backbone_presets
from keras_nlp.src.models.opt.opt_tokenizer import OPTTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (OPTBackbone, OPTTokenizer))
register_presets(backbone_presets, OPTBackbone)
5 changes: 1 addition & 4 deletions keras_nlp/src/models/pali_gemma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
PaliGemmaBackbone,
)
from keras_nlp.src.models.pali_gemma.pali_gemma_presets import backbone_presets
from keras_nlp.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (PaliGemmaBackbone, PaliGemmaTokenizer))
register_presets(backbone_presets, PaliGemmaBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/phi3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_nlp.src.models.phi3.phi3_presets import backbone_presets
from keras_nlp.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (Phi3Backbone, Phi3Tokenizer))
register_presets(backbone_presets, Phi3Backbone)
12 changes: 3 additions & 9 deletions keras_nlp/src/models/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
PreprocessingLayer,
)
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import builtin_presets
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty

Expand Down Expand Up @@ -120,13 +119,8 @@ def from_config(cls, config):

@classproperty
def presets(cls):
presets = list_presets(cls)
# We can also load backbone presets.
if cls.tokenizer_cls is not None:
presets.update(cls.tokenizer_cls.presets)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets
"""List built-in presets for a `Preprocessor` subclass."""
return builtin_presets(cls)

@classmethod
def from_preset(
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/src/models/preprocessor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ def test_preset_accessors(self):
all_presets = set(Preprocessor.presets.keys())
self.assertContainsSubset(bert_presets, all_presets)
self.assertContainsSubset(gpt2_presets, all_presets)
self.assertIn("bert_tiny_en_uncased", bert_presets)
self.assertNotIn("bert_tiny_en_uncased", gpt2_presets)
self.assertIn("gpt2_base_en", gpt2_presets)
self.assertNotIn("gpt2_base_en", bert_presets)
self.assertIn("bert_tiny_en_uncased", all_presets)
self.assertIn("gpt2_base_en", all_presets)

@pytest.mark.large
def test_from_preset(self):
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/src/models/roberta/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.src.models.roberta.roberta_presets import backbone_presets
from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (RobertaBackbone, RobertaTokenizer))
register_presets(backbone_presets, RobertaBackbone)
3 changes: 1 addition & 2 deletions keras_nlp/src/models/t5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.t5.t5_backbone import T5Backbone
from keras_nlp.src.models.t5.t5_presets import backbone_presets
from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (T5Backbone, T5Tokenizer))
register_presets(backbone_presets, T5Backbone)
11 changes: 2 additions & 9 deletions keras_nlp/src/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
from keras_nlp.src.utils.pipeline_model import PipelineModel
from keras_nlp.src.utils.preset_utils import TASK_CONFIG_FILE
from keras_nlp.src.utils.preset_utils import TASK_WEIGHTS_FILE
from keras_nlp.src.utils.preset_utils import builtin_presets
from keras_nlp.src.utils.preset_utils import find_subclass
from keras_nlp.src.utils.preset_utils import get_preset_loader
from keras_nlp.src.utils.preset_utils import list_presets
from keras_nlp.src.utils.preset_utils import list_subclasses
from keras_nlp.src.utils.preset_utils import save_serialized_object
from keras_nlp.src.utils.python_utils import classproperty

Expand Down Expand Up @@ -133,13 +132,7 @@ def from_config(cls, config):
@classproperty
def presets(cls):
"""List built-in presets for a `Task` subclass."""
presets = list_presets(cls)
# We can also load backbone presets.
if cls.backbone_cls is not None:
presets.update(cls.backbone_cls.presets)
for subclass in list_subclasses(cls):
presets.update(subclass.presets)
return presets
return builtin_presets(cls)

@classmethod
def from_preset(
Expand Down
6 changes: 6 additions & 0 deletions keras_nlp/src/models/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,12 @@ def test_preset_accessors(self):
all_presets = set(Task.presets.keys())
self.assertContainsSubset(bert_presets, all_presets)
self.assertContainsSubset(gpt2_presets, all_presets)
self.assertIn("bert_tiny_en_uncased", bert_presets)
self.assertNotIn("bert_tiny_en_uncased", gpt2_presets)
self.assertIn("gpt2_base_en", gpt2_presets)
self.assertNotIn("gpt2_base_en", bert_presets)
self.assertIn("bert_tiny_en_uncased", all_presets)
self.assertIn("gpt2_base_en", all_presets)

@pytest.mark.large
def test_from_preset(self):
Expand Down
3 changes: 1 addition & 2 deletions keras_nlp/src/models/whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

from keras_nlp.src.models.whisper.whisper_backbone import WhisperBackbone
from keras_nlp.src.models.whisper.whisper_presets import backbone_presets
from keras_nlp.src.models.whisper.whisper_tokenizer import WhisperTokenizer
from keras_nlp.src.utils.preset_utils import register_presets

register_presets(backbone_presets, (WhisperBackbone, WhisperTokenizer))
register_presets(backbone_presets, WhisperBackbone)
Loading

0 comments on commit 6d19bb3

Please sign in to comment.