diff --git a/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb b/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb new file mode 100644 index 00000000000000..6d618df1318d07 --- /dev/null +++ b/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb @@ -0,0 +1,517 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "eB-dvsMI09O4" + }, + "source": [ + "##### Copyright 2021 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "mwvC53CC1K3n" + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0WlR2-Y3v1mf" + }, + "source": [ + "End-to-end example to show how a pre-trained FLAX model can be downloaded, exported and run.\n", + "===============\n", + "1. The model used is Resnet50, downloaded from huggingface pre-trained models\n", + "2. Test image is of a CAT\n", + "3. `orbax-export` API is used to export the JAX Module to a TF Saved Model, along with image pre/post-processing functions.\n", + "4. 
TFLite Converter \u0026 Runtime are used to convert the TF Saved Model to .tflite and run on the test image.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zQXWQ7y11eIR" + }, + "source": [ + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/lite/examples/jax_conversion/jax_to_tflite_resnet50\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/tensorflow/tensorflow/lite/g3doc/examples/jax_conversion/jax_to_tflite_resnet50.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gq1373VwVlKV" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TUOym6hJnTxC" + }, + "source": [ + "```\n", + "!pip install orbax-export\n", + "\n", + "!pip install tf-nightly\n", + "\n", + "!pip install --upgrade jax jaxlib\n", + "\n", + "!pip install transformers flax\n", + 
"```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hDzBKqEfrafW" + }, + "source": [ + "## Show Test Image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2IOV4p_jripo" + }, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import jax\n", + "import requests\n", + "\n", + "url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/astrid_l_shaped.jpg\"\n", + "image = Image.open(requests.get(url, stream=True).raw)\n", + "image" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SPcMyI8ERmIB" + }, + "source": [ + "## Download and test pre-trained Resnet50 FLAX model from HuggingFace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a__LrY62PUtg" + }, + "outputs": [], + "source": [ + "from transformers import ConvNextImageProcessor, FlaxResNetForImageClassification\n", + "\n", + "image_processor = ConvNextImageProcessor.from_pretrained(\"microsoft/resnet-50\")\n", + "model = FlaxResNetForImageClassification.from_pretrained(\"microsoft/resnet-50\")\n", + "\n", + "inputs = image_processor(images=image, return_tensors=\"np\")\n", + "outputs = model(**inputs)\n", + "logits = outputs.logits\n", + "\n", + "# model predicts one of the 1000 ImageNet classes\n", + "predicted_class_idx = jax.numpy.argmax(logits, axis=-1)\n", + "print(\"Predicted class:\", model.config.id2label[predicted_class_idx.item()])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MKJHRRjGR1rC" + }, + "source": [ + "# Create a JAX wrapper for the model\n", + "A wrapper is needed in order to comply with the inputs TFLite accepts. TFLite accepts a tensor or a tuple-of-tensors."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Af2zv1qFSOR3" + }, + "outputs": [], + "source": [ + "import flax.linen as nn\n", + "from transformers import FlaxResNetForImageClassification\n", + "\n", + "\n", + "class Resnet50Wrapper(nn.Module):\n", + " pretrained_model_name: str = \"microsoft/resnet-50\" # Pre-trained model name\n", + "\n", + " def setup(self):\n", + " # Initialize the pre-trained ResNet50 model\n", + " self.model = FlaxResNetForImageClassification.from_pretrained(\n", + " self.pretrained_model_name\n", + " )\n", + "\n", + " def __call__(self, inputs):\n", + " # Process input images through the ResNet50 model\n", + " outputs = self.model(pixel_values=inputs)\n", + "\n", + " # Return logits or directly apply softmax for probabilities (optional)\n", + " return outputs.logits" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mrKtRfkhf73M" + }, + "source": [ + "## TensorFlow image pre-processor function\n", + "This essentially implements the underlying logic of the `ConvNextImageProcessor` class in huggingface transformers:\n", + "\u003e ```\n", + "\u003e image_processor = ConvNextImageProcessor.from_pretrained(\"microsoft/resnet-50\")\n", + "\u003e inputs = image_processor(images=image, return_tensors=\"np\")\n", + "\u003e ```\n", + "\n", + "This utility can be reused later during orbax-export as the `tf_preprocessor`.\n", + "\n", + "Note: We can perfectly use the result of `ConvNextImageProcessor` to run a TFLite model. But this example would like to showcase how `orbax-export` helps handle input/output pre/post-processing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "h85-Qb1PgANA" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import numpy as np\n", + "\n", + "\n", + "def resnet_image_processor(image_tensor):\n", + " # 1. 
Resize and Cast to Float32\n", + " image_resized = tf.image.resize(\n", + " image_tensor, (224, 224), method=tf.image.ResizeMethod.BILINEAR\n", + " )\n", + " image_float = tf.cast(image_resized, tf.float32)\n", + "\n", + " # 2. Normalize (Using TensorFlow Constants)\n", + " mean = tf.constant([0.485, 0.456, 0.406])\n", + " std = tf.constant([0.229, 0.224, 0.225])\n", + " image_normalized = (image_float / 255.0 - mean) / std\n", + "\n", + " # 3. Transpose for Channel-First Format\n", + " image_transposed = tf.transpose(image_normalized, perm=[2, 0, 1])\n", + "\n", + " # 4. Add Batch Dimension\n", + " return tf.expand_dims(image_transposed, axis=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3GD4Wb2K3VT_" + }, + "source": [ + "## Initialize the JAX model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2wHvNFpiThVL" + }, + "outputs": [], + "source": [ + "# Initialize the JAX Model\n", + "jax_model = Resnet50Wrapper()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uwP7XjS43M7W" + }, + "source": [ + "## Validate the wrapped JAX Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Gl8mypra3MGw" + }, + "outputs": [], + "source": [ + "# Convert the raw image values to RGB tensor\n", + "raw_image_tensor = tf.convert_to_tensor(np.array(image, dtype=np.float32))\n", + "\n", + "# Apply the above TF image preprocessing to get an input tensor supported by Resnet50\n", + "input_tensor = resnet_image_processor(raw_image_tensor)\n", + "\n", + "# Run the JAX model\n", + "jax_logits = jax_model.apply({}, input_tensor.numpy())\n", + "\n", + "jax_predicted_class_idx = jax.numpy.argmax(jax_logits, axis=-1)\n", + "print(\"Predicted class:\", model.config.id2label[jax_predicted_class_idx.item()])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lKZst83Mu6TD" + }, + "outputs": [], + "source": [ + "raw_image_tensor.shape" + ] + },
+ { + "cell_type": "markdown", + "metadata": { + "id": "zuAl8ubTTKxZ" + }, + "source": [ + "# Export to TFLite model and run" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Pz61IKnwjhbn" + }, + "source": [ + "## Export the JAX model to a TF Saved Model using orbax-export" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Cwq1rxEZjpoU" + }, + "outputs": [], + "source": [ + "from orbax.export import ExportManager, JaxModule, ServingConfig\n", + "\n", + "# Wrap the model params and function into a JaxModule.\n", + "jax_module = JaxModule({}, jax_model.apply, trainable=False)\n", + "\n", + "# Specify the serving configuration and export the model.\n", + "serving_config = ServingConfig(\n", + " \"serving_default\",\n", + " input_signature=[tf.TensorSpec([480, 640, 3], tf.float32, name=\"inputs\")],\n", + " tf_preprocessor=resnet_image_processor,\n", + " tf_postprocessor=lambda x: tf.argmax(x, axis=-1),\n", + ")\n", + "\n", + "export_manager = ExportManager(jax_module, [serving_config])\n", + "\n", + "saved_model_dir = \"resnet50_saved_model\"\n", + "export_manager.save(saved_model_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O00xe8MevVLj" + }, + "source": [ + "### Convert to a TFLite Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Uf2-ZMECmtJ3" + }, + "outputs": [], + "source": [ + "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n", + "tflite_model = converter.convert()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "n-fBAd64vZse" + }, + "source": [ + "### Apply TFLite Runtime API on the `raw_image_tensor`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "unFFby0Zmup-" + }, + "outputs": [], + "source": [ + "def run_tflite_model(tflite_model_content, input_tensor):\n", + " interpreter = tf.lite.Interpreter(model_content=tflite_model_content)\n", + " 
interpreter.allocate_tensors()\n", + "\n", + " input_details = interpreter.get_input_details()[0]\n", + " interpreter.set_tensor(input_details[\"index\"], input_tensor)\n", + " interpreter.invoke()\n", + "\n", + " output_details = interpreter.get_output_details()\n", + " return interpreter.get_tensor(output_details[0][\"index\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z9Rwk8iFvSBO" + }, + "outputs": [], + "source": [ + "output_data = run_tflite_model(tflite_model, raw_image_tensor)\n", + "print(\"Predicted class:\", model.config.id2label[output_data[0]])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NbtCrNLYKpK8" + }, + "source": [ + "## Export to TF Saved Model without pre/post-processing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qPLmCDniKvV3" + }, + "outputs": [], + "source": [ + "saved_model_dir_2 = \"resnet50_saved_model_1\"\n", + "\n", + "tf.saved_model.save(\n", + " jax_module,\n", + " saved_model_dir_2,\n", + " signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n", + " tf.TensorSpec([1, 3, 224, 224], tf.float32, name=\"inputs\")\n", + " ),\n", + " options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Wzp8M2ebLYoH" + }, + "outputs": [], + "source": [ + "converter_1 = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir_2)\n", + "tflite_model_1 = converter_1.convert()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3Di5ikJYMcKy" + }, + "outputs": [], + "source": [ + "output_data_1 = run_tflite_model(tflite_model_1, input_tensor)\n", + "tfl_predicted_class_idx_1 = tf.argmax(output_data_1, axis=-1).numpy()\n", + "print(\"Predicted class:\", model.config.id2label[tfl_predicted_class_idx_1[0]])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": 
"_SwsbUH_Smt7" + }, + "source": [ + "## Export from TF Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2MbcN1vjSt5E" + }, + "outputs": [], + "source": [ + "converter_2 = tf.lite.TFLiteConverter.from_concrete_functions(\n", + " [\n", + " jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(\n", + " tf.TensorSpec([1, 3, 224, 224], tf.float32, name=\"inputs\")\n", + " )\n", + " ]\n", + ")\n", + "tflite_model_2 = converter_2.convert()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WdM3p03AS-xq" + }, + "outputs": [], + "source": [ + "output_data_2 = run_tflite_model(tflite_model_2, input_tensor)\n", + "tfl_predicted_class_idx_2 = tf.argmax(output_data_2, axis=-1).numpy()\n", + "print(\"Predicted class:\", model.config.id2label[tfl_predicted_class_idx_2[0]])" + ] + } + ], + "metadata": { + "colab": { + "name": "jax_to_tflite_resnet50.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}