diff --git a/app/build.gradle b/app/build.gradle index d95e09c..41e64bc 100644 --- a/app/build.gradle +++ b/app/build.gradle @@ -42,5 +42,5 @@ dependencies { }) compile 'com.android.support:appcompat-v7:25.2.0' testCompile 'junit:junit:4.12' - compile files('libs/libandroid_tensorflow_inference_java.jar') + compile 'org.tensorflow:tensorflow-android:1.2.0' } diff --git a/app/libs/libandroid_tensorflow_inference_java.jar b/app/libs/libandroid_tensorflow_inference_java.jar deleted file mode 100755 index 3b8d93b..0000000 Binary files a/app/libs/libandroid_tensorflow_inference_java.jar and /dev/null differ diff --git a/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java b/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java index cdf5be5..3d0f3dd 100644 --- a/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java +++ b/app/src/main/java/com/mindorks/tensorflowexample/TensorFlowImageClassifier.java @@ -16,12 +16,6 @@ package com.mindorks.tensorflowexample; -import android.content.res.AssetManager; -import android.os.Trace; -import android.util.Log; - -import org.tensorflow.contrib.android.TensorFlowInferenceInterface; - import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; @@ -31,6 +25,12 @@ import java.util.PriorityQueue; import java.util.Vector; +import org.tensorflow.contrib.android.TensorFlowInferenceInterface; + +import android.content.res.AssetManager; +import android.support.v4.os.TraceCompat; +import android.util.Log; + /** * Created by amitshekhar on 16/03/17. */ @@ -40,7 +40,7 @@ */ public class TensorFlowImageClassifier implements Classifier { - private static final String TAG = "TensorFlowImageClassifier"; + private static final String TAG = "TFImageClassifier"; // Only return this many results with at least this confidence. private static final int MAX_RESULTS = 3; @@ -58,6 +58,8 @@ public class TensorFlowImageClassifier implements Classifier { private TensorFlowInferenceInterface inferenceInterface; + private boolean runStats = false; + private TensorFlowImageClassifier() { } @@ -96,10 +98,8 @@ public static Classifier create( } br.close(); - c.inferenceInterface = new TensorFlowInferenceInterface(); - if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) { - throw new RuntimeException("TF initialization failed"); - } + c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); + // The shape of the output is [N, NUM_CLASSES], where N is the batch size. int numClasses = (int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1); @@ -120,23 +120,22 @@ public static Classifier create( @Override public List<Recognition> recognizeImage(final float[] pixels) { // Log this method so that it can be analyzed with systrace. - Trace.beginSection("recognizeImage"); + TraceCompat.beginSection("recognizeImage"); // Copy the input data into TensorFlow. - Trace.beginSection("fillNodeFloat"); - inferenceInterface.fillNodeFloat( - inputName, new int[]{inputSize * inputSize}, pixels); - Trace.endSection(); + TraceCompat.beginSection("feed"); + inferenceInterface.feed(inputName, pixels, new long[]{inputSize * inputSize}); + TraceCompat.endSection(); // Run the inference call. - Trace.beginSection("runInference"); - inferenceInterface.runInference(outputNames); - Trace.endSection(); + TraceCompat.beginSection("run"); + inferenceInterface.run(outputNames, runStats); + TraceCompat.endSection(); // Copy the output Tensor back into the output array. - Trace.beginSection("readNodeFloat"); - inferenceInterface.readNodeFloat(outputName, outputs); - Trace.endSection(); + TraceCompat.beginSection("fetch"); + inferenceInterface.fetch(outputName, outputs); + TraceCompat.endSection(); // Find the best classifications. PriorityQueue<Recognition> pq = @@ -161,13 +160,13 @@ public int compare(Recognition lhs, Recognition rhs) { for (int i = 0; i < recognitionsSize; ++i) { recognitions.add(pq.poll()); } - Trace.endSection(); // "recognizeImage" + TraceCompat.endSection(); // "recognizeImage" return recognitions; } @Override public void enableStatLogging(boolean debug) { - inferenceInterface.enableStatLogging(debug); + runStats = debug; } @Override diff --git a/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so b/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so deleted file mode 100755 index 9390465..0000000 Binary files a/app/src/main/jniLibs/armeabi-v7a/libtensorflow_inference.so and /dev/null differ