diff --git a/keras_nlp/samplers/sampler.py b/keras_nlp/samplers/sampler.py index 50c58e56ef..a21093aa24 100644 --- a/keras_nlp/samplers/sampler.py +++ b/keras_nlp/samplers/sampler.py @@ -364,6 +364,10 @@ def one_step( ) return prompt + @classmethod + def from_config(cls, config): + return cls(**config) + def get_config(self): return { "jit_compile": self.jit_compile, diff --git a/keras_nlp/samplers/sampler_test.py b/keras_nlp/samplers/sampler_test.py index 1f065885fd..dd33f40093 100644 --- a/keras_nlp/samplers/sampler_test.py +++ b/keras_nlp/samplers/sampler_test.py @@ -17,21 +17,18 @@ import keras_nlp from keras_nlp.samplers.greedy_sampler import GreedySampler +from keras_nlp.samplers.top_k_sampler import TopKSampler class SamplerTest(tf.test.TestCase): def test_serialization(self): - sampler = keras_nlp.samplers.GreedySampler() - config = keras_nlp.samplers.serialize(sampler) - expected_config = { - "class_name": "keras_nlp>GreedySampler", - "config": { - "jit_compile": True, - }, - } - self.assertDictEqual(expected_config, config) + sampler = TopKSampler(k=5) + restored = keras_nlp.samplers.deserialize( + keras_nlp.samplers.serialize(sampler) + ) + self.assertDictEqual(sampler.get_config(), restored.get_config()) - def test_deserialization(self): + def test_get(self): # Test get from string. identifier = "greedy" sampler = keras_nlp.samplers.get(identifier)