Skip to content

Commit

Permalink
Remove unnecessary property declarations from ModelType (#459)
Browse files Browse the repository at this point in the history
* Allow creating ImageRecognitionModel without a ModelType

* Move "channelsFirst" and "inputColorMode" properties to the CV model type

These properties are only used in the classification models.
  • Loading branch information
juliabeliaeva authored Sep 29, 2022
1 parent bad2fdf commit edfe8c4
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.jetbrains.kotlinx.dl.api.inference.keras.loaders

import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Identity
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation

Expand All @@ -20,18 +19,6 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
/** Relative path to model for local and S3 buckets storages. */
public val modelRelativePath: String

/**
* If true, the second dimension is the number of image channels (short notation `NCHW`);
* otherwise, the channels are in the last position (short notation `NHWC`).
*/
public val channelsFirst: Boolean

/**
* An expected channels order for the input image.
* Note: the wrong choice of this parameter can significantly impact the model's performance.
*/
public val inputColorMode: ColorMode

/**
* Preprocessing [Operation] specific for this model type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@

package org.jetbrains.kotlinx.dl.api.inference.imagerecognition

import com.beust.klaxon.JsonArray
import com.beust.klaxon.JsonObject
import com.beust.klaxon.Parser
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Identity
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand All @@ -29,14 +27,16 @@ import java.io.IOException
*/
public class ImageRecognitionModel(
internalModel: InferenceModel,
private val modelType: ModelType<out InferenceModel, out InferenceModel>
private val inputColorMode: ColorMode,
private val channelsFirst: Boolean,
private val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> = Identity()
) : ImageRecognitionModelBase<BufferedImage>(internalModel) {
/** Class labels for ImageNet dataset. */
override val classLabels: Map<Int, String> = Imagenet.V1k.labels()

override val preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>
get() {
val (width, height) = if (modelType.channelsFirst)
val (width, height) = if (channelsFirst)
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
Expand All @@ -47,9 +47,9 @@ public class ImageRecognitionModel(
outputWidth = width.toInt()
interpolation = InterpolationType.BILINEAR
}
.convert { colorMode = modelType.inputColorMode }
.convert { colorMode = inputColorMode }
.toFloatArray {}
.call(modelType.preprocessor)
.call(preprocessor)
}

/**
Expand Down Expand Up @@ -87,6 +87,11 @@ public class ImageRecognitionModel(
saveOptimizerState: Boolean,
copyWeights: Boolean
): ImageRecognitionModel {
return ImageRecognitionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights), modelType)
return ImageRecognitionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
inputColorMode,
channelsFirst,
preprocessor
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModelM
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.normalize
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.preprocessing.rescale
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/**
* Set of pretrained mobile-friendly ONNX models
*/
public object ONNXModels {
/** Image classification models */
/** Image classification models.
*
* @property [channelsFirst] If true, the second dimension is the number of image channels
* (short notation `NCHW`);
* otherwise, the channels are in the last position (short notation `NHWC`).
* @property [inputColorMode] The channel order expected for the input image.
* Note: the wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean,
override val inputColorMode: ColorMode = ColorMode.RGB,
protected val channelsFirst: Boolean,
private val inputColorMode: ColorMode = ColorMode.RGB,
) : OnnxModelType<T, ImageRecognitionModel> {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(modelHub.loadModel(this) as OnnxInferenceModel, this)
return ImageRecognitionModel(modelHub.loadModel(this) as OnnxInferenceModel, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -78,19 +88,16 @@ public object ONNXModels {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(
modelHub.loadModel(this),
this,
Imagenet.V1001.labels()
channelsFirst,
classLabels = Imagenet.V1001.labels()
)
}
}
}

/** Pose detection models. */
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) : OnnxModelType<T, U> {
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
* (edges are available in [org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.edgeKeyPointsPairs]
Expand Down Expand Up @@ -147,11 +154,8 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) : OnnxModelType<T, U> {
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a real-time neural network for object detection that detects 90 different classes
* (labels are available in [org.jetbrains.kotlinx.dl.dataset.Coco.V2017]).
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.classification

import android.graphics.Bitmap
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModelBase
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.imagenetLabels
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

Expand All @@ -18,14 +15,15 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
*/
public open class ImageRecognitionModel(
internalModel: OnnxInferenceModel,
private val modelType: ModelType<out InferenceModel, out InferenceModel>,
private val channelsFirst: Boolean,
private val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> = Identity(),
override val classLabels: Map<Int, String> = Imagenet.V1k.labels()
) : ImageRecognitionModelBase<Bitmap>(internalModel), ExecutionProviderCompatible, CameraXCompatibleModel {
override var targetRotation: Float = 0f

override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
get() {
val (width, height) = if (modelType.channelsFirst)
val (width, height) = if (channelsFirst)
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
Expand All @@ -36,8 +34,8 @@ public open class ImageRecognitionModel(
outputWidth = width.toInt()
}
.rotate { degrees = targetRotation }
.toFloatArray { layout = if (modelType.channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
.call(modelType.preprocessor)
.toFloatArray { layout = if (channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
.call(preprocessor)
}

override fun initializeWith(vararg executionProviders: ExecutionProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/** Models in the ONNX format and running via ONNX Runtime. */
public object ONNXModels {
/** Image recognition models and preprocessing. */
/** Image recognition models and preprocessing.
*
* @property [channelsFirst] If true, the second dimension is the number of image channels
* (short notation `NCHW`);
* otherwise, the channels are in the last position (short notation `NHWC`).
* @property [inputColorMode] The channel order expected for the input image.
* Note: the wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean,
override val inputColorMode: ColorMode = ColorMode.RGB,
protected val channelsFirst: Boolean,
private val inputColorMode: ColorMode = ColorMode.RGB,
/** If true, model is shipped without last few layers and could be used for transfer learning and fine-tuning with TF Runtime. */
internal var noTop: Boolean = false
) : OnnxModelType<T, ImageRecognitionModel> {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(modelHub.loadModel(this), this)
return ImageRecognitionModel(modelHub.loadModel(this), inputColorMode, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -551,11 +558,7 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a real-time neural network for object detection that detects 80 different classes
Expand Down Expand Up @@ -825,11 +828,7 @@ public object ONNXModels {
}

/** Face alignment models and preprocessing. */
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a neural network for face alignment that take RGB images of faces as input and produces coordinates of 106 faces landmarks.
Expand All @@ -850,11 +849,7 @@ public object ONNXModels {
}

/** Pose detection models. */
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
* All weights are imported from the `Keras.applications` or `ONNX.models` project and preprocessed with the KotlinDL project.
*/
public object TFModels {
/** Image recognition models and preprocessing. */
/** Image recognition models and preprocessing.
*
* @property [channelsFirst] If true, the second dimension is the number of image channels
* (short notation `NCHW`);
* otherwise, the channels are in the last position (short notation `NHWC`).
* @property [inputColorMode] The channel order expected for the input image.
* Note: the wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : GraphTrainableModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = false,
override val inputColorMode: ColorMode = ColorMode.RGB,
private val channelsFirst: Boolean = false,
private val inputColorMode: ColorMode = ColorMode.RGB,
public var inputShape: IntArray? = null,
internal var noTop: Boolean = false
) : ModelType<T, ImageRecognitionModel> {
Expand All @@ -41,7 +48,8 @@ public object TFModels {
}

override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return buildImageRecognitionModel(modelHub, this)
val model = loadModel(modelHub, this)
return ImageRecognitionModel(model, inputColorMode, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -518,10 +526,7 @@ public object TFModels {
}
}

private fun buildImageRecognitionModel(
modelHub: ModelHub,
modelType: ModelType<out GraphTrainableModel, ImageRecognitionModel>
): ImageRecognitionModel {
private fun loadModel(modelHub: ModelHub, modelType: CV<out GraphTrainableModel>): GraphTrainableModel {
modelHub as TFModelHub
val model = modelHub.loadModel(modelType)
// TODO: this part is not needed for inference (if we could add manually Softmax at the end of the graph)
Expand All @@ -534,7 +539,6 @@ public object TFModels {
val hdfFile = modelHub.loadWeights(modelType)

model.loadWeights(hdfFile)

return ImageRecognitionModel(model, modelType)
return model
}
}

0 comments on commit edfe8c4

Please sign in to comment.