# ==============================================================================
# File: lit_nlp/examples/models/tfx_model.py
# (from git diff: deleted file mode 100644, index 29df2ab9..00000000)
# ==============================================================================
"""Wrapper for using TFX-generated models within LIT."""
from collections.abc import Iterator

import attr
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
import tensorflow as tf
import tensorflow_text as tf_text  # pylint: disable=unused-import

_SERVING_DEFAULT_SIGNATURE = 'serving_default'


@attr.s(auto_attribs=True)
class TFXModelConfig(object):
  """Configuration object for TFX Models."""
  # Location handed to tf.saved_model.load() by TFXModel.
  path: str
  # Spec whose keys select which input fields are forwarded to the model.
  input_spec: lit_types.Spec
  # Spec whose keys select which model outputs are returned.
  output_spec: lit_types.Spec
  # Key into the loaded model's `signatures` mapping.
  signature: str = _SERVING_DEFAULT_SIGNATURE


def _values_to_feature(values: list) -> tf.train.Feature:
  """Builds a tf.train.Feature whose kind is chosen from *all* of `values`.

  Args:
    values: Non-empty list of ints, floats, strs or bytes.

  Returns:
    An int64 feature if every value is an int, a float feature if every value
    is an int or float, and a bytes feature otherwise.

  Raises:
    TypeError: If the bytes branch is taken and a value is neither `str` nor
      `bytes`.
  """
  if all(isinstance(x, int) for x in values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  # Mixed int/float lists (e.g. [1, 2.5]) are numeric and must not be sent to
  # Int64List just because the first element happens to be an int.
  if all(isinstance(x, (int, float)) for x in values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))
  # bytes(x, 'utf-8') rejects values that are already bytes, so pass those
  # through untouched and only encode the rest.
  encoded = [x if isinstance(x, bytes) else bytes(x, 'utf-8') for x in values]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=encoded))


# TODO(b/188036366): Revisit the assumed mapping between input values and
# TF.Examples.
def _inputs_to_serialized_example(input_dict: lit_types.JsonDict):
  """Converts the input dictionary to a serialized tf example.

  Args:
    input_dict: Mapping from feature name to a single value or a list/tuple of
      values.

  Returns:
    The serialized `tf.train.Example` as bytes. Fields whose value is an empty
    list/tuple are left out of the example, since there is no element to infer
    a feature type from.
  """
  feature_dict = {}
  for name, value in input_dict.items():
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    if not values:
      # Nothing to encode; omitting the key avoids indexing an empty list.
      continue
    feature_dict[name] = _values_to_feature(values)
  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example.SerializeToString()


class TFXModel(lit_model.BatchedModel):
  """Wrapper for querying a TFX-generated SavedModel."""

  def __init__(self, config: TFXModelConfig):
    """Loads the SavedModel at `config.path` and stores the specs/signature."""
    self._model = tf.saved_model.load(config.path)
    self._signature = config.signature
    self._input_spec = config.input_spec
    self._output_spec = config.output_spec

  def _expand_binary_preds(self, name: str, value):
    """Turns a scalar MulticlassPreds output into a two-class list.

    Args:
      name: Output field name; must be a key of the output spec.
      value: The squeezed prediction converted with `.tolist()`.

    Returns:
      `[1 - value, value]` when the field is a `MulticlassPreds` and `value`
      is not already a list; otherwise `value` unchanged.
    """
    is_multiclass = isinstance(self._output_spec[name],
                               lit_types.MulticlassPreds)
    if is_multiclass and not isinstance(value, list):
      # If doing Multiclass Prediction for a Binary Classifier.
      return [1 - value, value]
    return value

  def predict_minibatch(  # pytype: disable=signature-mismatch  # overriding-return-type-checks
      self, inputs: list[lit_types.JsonDict]
  ) -> Iterator[lit_types.JsonDict]:
    """Yields one prediction dict per input, in input order.

    Each input is filtered to the keys of the input spec, serialized as a
    tf.Example and sent to the configured signature as a batch of one. Outputs
    are squeezed, converted to Python values and filtered to the keys of the
    output spec.

    Args:
      inputs: Input examples.

    Yields:
      A dict mapping each output-spec field produced by the model to its value.
    """
    for example in inputs:
      filtered_inputs = {
          k: v for k, v in example.items() if k in self._input_spec
      }
      serialized = _inputs_to_serialized_example(filtered_inputs)
      outputs = self._model.signatures[self._signature](
          tf.constant([serialized]))
      # Build the result in a single pass instead of reassigning entries of a
      # dict while iterating over it.
      yield {
          name: self._expand_binary_preds(
              name, tf.squeeze(tensor).numpy().tolist())
          for name, tensor in outputs.items()
          if name in self._output_spec
      }

  def input_spec(self) -> lit_types.Spec:
    """Returns the input spec supplied in the config."""
    return self._input_spec

  def output_spec(self) -> lit_types.Spec:
    """Returns the output spec supplied in the config."""
    return self._output_spec


# ==============================================================================
# File: lit_nlp/examples/models/tfx_model_test.py
# (from git diff: deleted file mode 100644, index e7b977c9..00000000)
# ==============================================================================
"""Tests for lit_nlp.examples.models.tfx_model."""
import shutil
import tempfile

from lit_nlp.api import types as lit_types
from lit_nlp.examples.models import tfx_model
import tensorflow as tf


class TfxModelTest(tf.test.TestCase):
  """Round-trips a tiny Keras SavedModel through the TFXModel wrapper."""

  def setUp(self):
    """Saves a one-layer model that parses serialized tf.Examples."""
    super().setUp()
    self._path = tempfile.mkdtemp()
    # mkdtemp() never deletes the directory on its own; remove it after each
    # test so repeated runs do not accumulate SavedModels on disk.
    self.addCleanup(shutil.rmtree, self._path, ignore_errors=True)
    input_layer = tf.keras.layers.Input(
        shape=(1,), dtype=tf.string, name='example'
    )
    parsed_input = tf.io.parse_example(
        tf.reshape(input_layer, [-1]),
        {'input_0': tf.io.FixedLenFeature([1], dtype=tf.float32)})
    output_layer = tf.keras.layers.Dense(
        1, name='output_0')(
            parsed_input['input_0'])
    model = tf.keras.Model(input_layer, output_layer)
    model.compile(
        # `learning_rate` is the supported argument name; `lr` is deprecated.
        optimizer=tf.keras.optimizers.Adam(learning_rate=.001),
        loss=tf.keras.losses.binary_crossentropy)
    model.save(self._path)

  def testTfxModel(self):
    """A scalar output declared as MulticlassPreds becomes two floats."""
    input_spec = {'input_0': lit_types.Scalar()}
    output_spec = {
        'output_0':
            lit_types.MulticlassPreds(vocab=['0', '1'], parent='input_0')
    }
    config = tfx_model.TFXModelConfig(self._path, input_spec, output_spec)
    wrapper = tfx_model.TFXModel(config)
    result = list(wrapper.predict([{'input_0': 0.5}]))
    self.assertLen(result, 1)
    result = result[0]
    self.assertListEqual(list(result.keys()), ['output_0'])
    self.assertLen(result['output_0'], 2)
    self.assertIsInstance(result['output_0'][0], float)
    self.assertIsInstance(result['output_0'][1], float)
    self.assertDictEqual(wrapper.input_spec(), input_spec)
    self.assertDictEqual(wrapper.output_spec(), output_spec)


if __name__ == '__main__':
  tf.test.main()