Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix channels ordering for classification models (#400) #401

Merged
merged 2 commits into from
Jul 12, 2022

Conversation

ermolenkodev
Copy link
Collaborator

I checked which channel order (RGB or BGR) each model expects by running the script below.

/**
 * Measures, for every listed pretrained classification model, the top-1 accuracy obtained
 * when the input images are fed in BGR channel order and in RGB channel order.
 *
 * For each model two lines with the accuracy are printed to stdout and appended to
 * `rgb_bgr_input_test.log` in the working directory.
 *
 * @param args `args[0]` is the directory that is walked recursively for image files,
 * `args[1]` is a text file with one `<imageId>,<label>` pair per line, where `imageId`
 * is the image file name without extension. Missing entries fall back to
 * `imagenette2/val/` and `imagenette2/lbls.txt`.
 * @throws IllegalArgumentException if a non-blank line of the labels file has no comma.
 * @throws NumberFormatException if a label is not an integer.
 * @throws RuntimeException if a model of an unsupported family is listed.
 */
fun main(args: Array<String> = arrayOf("imagenette2/val/", "imagenette2/lbls.txt")) {
    // When the program is started from the command line without arguments, `args` is an
    // empty array (the Kotlin default value is not applied), so destructuring it would throw.
    // Resolve every path individually instead.
    val imagenettePath = args.getOrElse(0) { "imagenette2/val/" }
    val lblsPath = args.getOrElse(1) { "imagenette2/lbls.txt" }

    val files = File(imagenettePath).walk().filter { it.isFile }.toList()

    // Image id -> ground truth label.
    val lblsMap = mutableMapOf<String, Int>()
    File(lblsPath).forEachLine { line ->
        if (line.isBlank()) return@forEachLine
        val parts = line.split(",")
        require(parts.size >= 2) { "Malformed label line (expected \"<imageId>,<label>\"): $line" }
        lblsMap[parts[0]] = parts[1].trim().toInt()
    }

    val tfModelHub =
        TFModelHub(cacheDirectory = File("cache/pretrainedModels"))
    val onnxModelHub =
        ONNXModelHub(cacheDirectory = File("cache/pretrainedModels"))

    val modelsToTest = listOf(
        TFModels.CV.ResNet50(),
        TFModels.CV.DenseNet121(),
        TFModels.CV.DenseNet169(),
        TFModels.CV.DenseNet201(),
        TFModels.CV.Inception(),
        TFModels.CV.MobileNet(),
        TFModels.CV.MobileNetV2(),
        TFModels.CV.NASNetLarge(),
        TFModels.CV.NASNetMobile(),
        TFModels.CV.ResNet101(),
        TFModels.CV.ResNet101v2(),
        TFModels.CV.ResNet152(),
        TFModels.CV.ResNet152v2(),
        TFModels.CV.ResNet18(),
        TFModels.CV.ResNet34(),
        TFModels.CV.ResNet50v2(),
        TFModels.CV.VGG16(),
        TFModels.CV.VGG19(),
        TFModels.CV.Xception(),
        ONNXModels.CV.EfficientNetB0(),
        ONNXModels.CV.EfficientNetB1(),
        ONNXModels.CV.EfficientNetB2(),
        ONNXModels.CV.EfficientNetB3(),
        ONNXModels.CV.EfficientNetB4(),
        ONNXModels.CV.EfficientNetB5(),
        ONNXModels.CV.EfficientNetB6(),
        ONNXModels.CV.EfficientNetB7(),
        ONNXModels.CV.EfficientNet4Lite(),
        ONNXModels.CV.ResNet101(),
        ONNXModels.CV.ResNet101v2(),
        ONNXModels.CV.ResNet152(),
        ONNXModels.CV.ResNet152v2(),
        ONNXModels.CV.ResNet18(),
        ONNXModels.CV.ResNet18v2(),
        ONNXModels.CV.ResNet34(),
        ONNXModels.CV.ResNet34v2(),
        ONNXModels.CV.ResNet50(),
        ONNXModels.CV.ResNet50v2(),
        ONNXModels.CV.ResNet50custom,
    )

    val outputLog = File("rgb_bgr_input_test.log")

    // Writes the same message to the log file and to stdout.
    fun log(message: String) {
        outputLog.appendText("$message\n")
        println(message)
    }

    for (modelType in modelsToTest) {
        log("Testing model ${modelType.javaClass.name}.")

        val model = when (modelType) {
            is TFModels.CV<*> -> {
                if (modelType is TFModels.CV.DenseNet121 || modelType is TFModels.CV.DenseNet169 || modelType is TFModels.CV.DenseNet201) {
                    // DenseNet models are assembled by hand: the weights of the first convolution
                    // and batch normalization layers are read from explicitly given HDF5 paths.
                    val denseNet = tfModelHub.loadModel(modelType)

                    denseNet.compile(
                        optimizer = Adam(),
                        loss = Losses.MAE,
                        metric = Metrics.ACCURACY
                    )

                    val hdfFile = tfModelHub.loadWeights(modelType)

                    val weightPaths = listOf(
                        LayerConvOrDensePaths(
                            "conv1_conv",
                            "/conv1/conv/conv1/conv/kernel:0",
                            "/conv1/conv/conv1/conv/bias:0"
                        ),
                        LayerBatchNormPaths(
                            "conv1_bn",
                            "/conv1/bn/conv1/bn/gamma:0",
                            "/conv1/bn/conv1/bn/beta:0",
                            "/conv1/bn/conv1/bn/moving_mean:0",
                            "/conv1/bn/conv1/bn/moving_variance:0"
                        )
                    )
                    denseNet.loadWeightsByPaths(hdfFile, weightPaths, missedWeights = MissedWeightsStrategy.LOAD_CUSTOM_PATH)
                    ImageRecognitionModel(denseNet, modelType)
                } else {
                    modelType.pretrainedModel(tfModelHub)
                }
            }
            is ONNXModels.CV<*> -> modelType.pretrainedModel(onnxModelHub)
            else -> throw RuntimeException("Unsupported model type: ${modelType.javaClass.name}")
        }

        try {
            // Class name -> class index, i.e. the inverse of the model's label map.
            val nameToLbl = model.imageNetClassLabels.entries.associate { it.value to it.key }

            // For channels-first models index 0 of the input dimensions is skipped.
            val (width, height) = if (modelType.channelsFirst)
                Pair(model.inputDimensions[1], model.inputDimensions[2])
            else
                Pair(model.inputDimensions[0], model.inputDimensions[1])

            // Runs the model over all images converted to [inputColorMode] and logs the share
            // of images whose predicted label equals the ground truth; [modeName] is the label
            // used in the log lines.
            fun evaluate(inputColorMode: ColorMode, modeName: String) {
                val preprocessing: Preprocessing = preprocess {
                    transformImage {
                        resize {
                            outputHeight = height.toInt()
                            outputWidth = width.toInt()
                            interpolation = InterpolationType.BILINEAR
                        }
                        convert { colorMode = inputColorMode }
                    }
                    transformTensor {
                        sharpen {
                            modelTypePreprocessing = modelType
                        }
                    }
                }

                log("Testing with $modeName input")
                val correctCounter = files.count { imgFile ->
                    val className = model.predictObject(imageFile = imgFile, preprocessing)
                    nameToLbl[className] == lblsMap[imgFile.nameWithoutExtension]
                }
                log("Accuracy in $modeName mode - ${correctCounter.toFloat() / files.size}")
            }

            // "GBR" (sic, the input is BGR) is kept as is so that new log lines stay
            // comparable with previously recorded logs.
            evaluate(ColorMode.BGR, "GBR")
            evaluate(ColorMode.RGB, "RGB")
        } finally {
            // Release the model even if evaluation fails.
            model.close()
        }

        log("=======================================================")
    }
}

The script produces output like the following (note that "GBR" in the log lines refers to BGR channel order):

Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet50.
Testing with GBR input
Accuracy in GBR mode - 0.8512102
Testing with RGB input
Accuracy in RGB mode - 0.71592355
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$DenseNet121.
Testing with GBR input
Accuracy in GBR mode - 0.72611463
Testing with RGB input
Accuracy in RGB mode - 0.82573247
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$DenseNet169.
Testing with GBR input
Accuracy in GBR mode - 0.7689172
Testing with RGB input
Accuracy in RGB mode - 0.8690446
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$DenseNet201.
Testing with GBR input
Accuracy in GBR mode - 0.7729936
Testing with RGB input
Accuracy in RGB mode - 0.88764334
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$Inception.
Testing with GBR input
Accuracy in GBR mode - 0.8593631
Testing with RGB input
Accuracy in RGB mode - 0.9087898
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$MobileNet.
Testing with GBR input
Accuracy in GBR mode - 0.8099363
Testing with RGB input
Accuracy in RGB mode - 0.8825478
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$MobileNetV2.
Testing with GBR input
Accuracy in GBR mode - 0.73987263
Testing with RGB input
Accuracy in RGB mode - 0.86394906
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$NASNetLarge.
Testing with GBR input
Accuracy in GBR mode - 0.89859873
Testing with RGB input
Accuracy in RGB mode - 0.93630576
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$NASNetMobile.
Testing with GBR input
Accuracy in GBR mode - 0.75949043
Testing with RGB input
Accuracy in RGB mode - 0.8275159
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet101.
Testing with GBR input
Accuracy in GBR mode - 0.8601274
Testing with RGB input
Accuracy in RGB mode - 0.7309554
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet101v2.
Testing with GBR input
Accuracy in GBR mode - 0.72942674
Testing with RGB input
Accuracy in RGB mode - 0.78828025
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet152.
Testing with GBR input
Accuracy in GBR mode - 0.8703185
Testing with RGB input
Accuracy in RGB mode - 0.7332484
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet152v2.
Testing with GBR input
Accuracy in GBR mode - 0.74471337
Testing with RGB input
Accuracy in RGB mode - 0.80101913
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet18.
Testing with GBR input
Accuracy in GBR mode - 0.36866242
Testing with RGB input
Accuracy in RGB mode - 0.4033121
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet34.
Testing with GBR input
Accuracy in GBR mode - 0.34649682
Testing with RGB input
Accuracy in RGB mode - 0.36484078
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$ResNet50v2.
Testing with GBR input
Accuracy in GBR mode - 0.71133757
Testing with RGB input
Accuracy in RGB mode - 0.77171975
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$VGG16.
Testing with GBR input
Accuracy in GBR mode - 0.76764333
Testing with RGB input
Accuracy in RGB mode - 0.61910826
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$VGG19.
Testing with GBR input
Accuracy in GBR mode - 0.77070063
Testing with RGB input
Accuracy in RGB mode - 0.6557962
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.keras.loaders.TFModels$CV$Xception.
Testing with GBR input
Accuracy in GBR mode - 0.8568153
Testing with RGB input
Accuracy in RGB mode - 0.9031847
=======================================================

Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB0.
Testing with GBR input
Accuracy in GBR mode - 0.7388535
Testing with RGB input
Accuracy in RGB mode - 0.81452227
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB1.
10:28:38.208 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
10:30:34.399 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.80407643
Testing with RGB input
Accuracy in RGB mode - 0.8568153
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB2.
10:33:29.234 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
10:33:48.214 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.8
Testing with RGB input
Accuracy in RGB mode - 0.8532484
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB3.
10:37:16.233 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
10:38:05.399 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.8509554
Testing with RGB input
Accuracy in RGB mode - 0.9044586
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB4.
10:43:30.338 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
10:44:34.476 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.8746497
Testing with RGB input
Accuracy in RGB mode - 0.90828025
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB5.
10:54:34.866 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
10:55:01.024 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.9001274
Testing with RGB input
Accuracy in RGB mode - 0.9261147
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB6.
11:19:56.814 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
11:20:54.890 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.9008917
Testing with RGB input
Accuracy in RGB mode - 0.9263694
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNetB7.
Testing with GBR input
Accuracy in GBR mode - 0.9044586
Testing with RGB input
Accuracy in RGB mode - 0.9357962
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$EfficientNet4Lite.
13:06:02.359 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
13:06:25.830 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.7602548
Testing with RGB input
Accuracy in RGB mode - 0.8512102
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet101.
13:11:44.573 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
13:14:16.692 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.755414
Testing with RGB input
Accuracy in RGB mode - 0.86522293
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet101v2.
13:22:17.919 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
13:25:59.289 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.7765605
Testing with RGB input
Accuracy in RGB mode - 0.8873885
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet152.
13:32:29.635 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
13:37:03.118 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.76203823
Testing with RGB input
Accuracy in RGB mode - 0.8624204
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet152v2.
13:46:47.109 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
13:50:28.685 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.79694265
Testing with RGB input
Accuracy in RGB mode - 0.8896815
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet18.
14:03:34.588 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:04:14.751 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.62165606
Testing with RGB input
Accuracy in RGB mode - 0.75974524
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet18v2.
14:08:06.736 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:08:44.984 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.6244586
Testing with RGB input
Accuracy in RGB mode - 0.7556688
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet34.
14:12:47.103 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:14:35.019 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.70726115
Testing with RGB input
Accuracy in RGB mode - 0.82420385
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet34v2.
14:19:47.290 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:21:09.790 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.6965605
Testing with RGB input
Accuracy in RGB mode - 0.8371974
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet50.
14:26:33.774 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:27:50.324 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.71847135
Testing with RGB input
Accuracy in RGB mode - 0.8402548
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet50v2.
14:33:39.863 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:35:59.990 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.72254777
Testing with RGB input
Accuracy in RGB mode - 0.85732484
=======================================================
Testing model org.jetbrains.kotlinx.dl.api.inference.onnx.ONNXModels$CV$ResNet50custom.
14:42:13.507 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is started!
14:43:49.767 [main] DEBUG org.jetbrains.kotlinx.dl.api.inference.loaders.ONNXModelHub - Model loading is finished!
Testing with GBR input
Accuracy in GBR mode - 0.8512102
Testing with RGB input
Accuracy in RGB mode - 0.71592355
=======================================================

@ermolenkodev ermolenkodev self-assigned this Jun 23, 2022
@ermolenkodev ermolenkodev added the bug Something isn't working label Jun 23, 2022
@ermolenkodev ermolenkodev changed the base branch from master to release_0.4 July 12, 2022 07:25
@ermolenkodev ermolenkodev merged commit 7a9d212 into Kotlin:release_0.4 Jul 12, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants