Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary property declarations from ModelType #459

Merged
merged 2 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.jetbrains.kotlinx.dl.api.inference.keras.loaders

import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Identity
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation

Expand All @@ -20,18 +19,6 @@ public interface ModelType<T : InferenceModel, U : InferenceModel> {
/** Relative path to model for local and S3 buckets storages. */
public val modelRelativePath: String

/**
* If true, the second dimension holds the number of image channels (a layout abbreviated as `NCHW`);
* otherwise, channels are at the last position (a layout abbreviated as `NHWC`).
*/
public val channelsFirst: Boolean

/**
* The expected channel order of the input image.
* Note: a wrong choice of this parameter can significantly impact the model's performance.
*/
public val inputColorMode: ColorMode

/**
* Preprocessing [Operation] specific for this model type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@

package org.jetbrains.kotlinx.dl.api.inference.imagerecognition

import com.beust.klaxon.JsonArray
import com.beust.klaxon.JsonObject
import com.beust.klaxon.Parser
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.image.ImageConverter
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Identity
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.call
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
Expand All @@ -29,14 +27,16 @@ import java.io.IOException
*/
public class ImageRecognitionModel(
internalModel: InferenceModel,
private val modelType: ModelType<out InferenceModel, out InferenceModel>
private val inputColorMode: ColorMode,
private val channelsFirst: Boolean,
private val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> = Identity()
) : ImageRecognitionModelBase<BufferedImage>(internalModel) {
/** Class labels for ImageNet dataset. */
override val classLabels: Map<Int, String> = Imagenet.V1k.labels()

override val preprocessing: Operation<BufferedImage, Pair<FloatArray, TensorShape>>
get() {
val (width, height) = if (modelType.channelsFirst)
val (width, height) = if (channelsFirst)
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
Expand All @@ -47,9 +47,9 @@ public class ImageRecognitionModel(
outputWidth = width.toInt()
interpolation = InterpolationType.BILINEAR
}
.convert { colorMode = modelType.inputColorMode }
.convert { colorMode = inputColorMode }
.toFloatArray {}
.call(modelType.preprocessor)
.call(preprocessor)
}

/**
Expand Down Expand Up @@ -87,6 +87,11 @@ public class ImageRecognitionModel(
saveOptimizerState: Boolean,
copyWeights: Boolean
): ImageRecognitionModel {
return ImageRecognitionModel(internalModel.copy(copiedModelName, saveOptimizerState, copyWeights), modelType)
return ImageRecognitionModel(
internalModel.copy(copiedModelName, saveOptimizerState, copyWeights),
inputColorMode,
channelsFirst,
preprocessor
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,31 @@ import org.jetbrains.kotlinx.dl.api.inference.onnx.objectdetection.SSDLikeModelM
import org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.SinglePoseDetectionModel
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.image.ColorMode
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.preprocessing.Operation
import org.jetbrains.kotlinx.dl.dataset.preprocessing.normalize
import org.jetbrains.kotlinx.dl.dataset.preprocessing.pipeline
import org.jetbrains.kotlinx.dl.dataset.preprocessing.rescale
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/**
* Set of pretrained mobile-friendly ONNX models
*/
public object ONNXModels {
/** Image classification models */
/** Image classification models.
*
* @property [channelsFirst] If true, the second dimension holds the number of image channels
* (a layout abbreviated as `NCHW`);
* otherwise, channels are at the last position (a layout abbreviated as `NHWC`).
* @property [inputColorMode] The expected channel order of the input image.
* Note: a wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean,
override val inputColorMode: ColorMode = ColorMode.RGB,
protected val channelsFirst: Boolean,
private val inputColorMode: ColorMode = ColorMode.RGB,
) : OnnxModelType<T, ImageRecognitionModel> {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(modelHub.loadModel(this) as OnnxInferenceModel, this)
return ImageRecognitionModel(modelHub.loadModel(this) as OnnxInferenceModel, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -78,19 +88,16 @@ public object ONNXModels {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(
modelHub.loadModel(this),
this,
Imagenet.V1001.labels()
channelsFirst,
classLabels = Imagenet.V1001.labels()
)
}
}
}

/** Pose detection models. */
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) : OnnxModelType<T, U> {
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
* (edges are available in [org.jetbrains.kotlinx.dl.api.inference.onnx.posedetection.edgeKeyPointsPairs]
Expand Down Expand Up @@ -147,11 +154,8 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) : OnnxModelType<T, U> {
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a real-time neural network for object detection that detects 90 different classes
* (labels are available in [org.jetbrains.kotlinx.dl.dataset.Coco.V2017]).
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
package org.jetbrains.kotlinx.dl.api.inference.onnx.classification

import android.graphics.Bitmap
import org.jetbrains.kotlinx.dl.api.inference.InferenceModel
import org.jetbrains.kotlinx.dl.api.inference.imagerecognition.ImageRecognitionModelBase
import org.jetbrains.kotlinx.dl.api.inference.keras.loaders.ModelType
import org.jetbrains.kotlinx.dl.api.inference.onnx.CameraXCompatibleModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.ExecutionProviderCompatible
import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider
import org.jetbrains.kotlinx.dl.dataset.Imagenet
import org.jetbrains.kotlinx.dl.dataset.imagenetLabels
import org.jetbrains.kotlinx.dl.dataset.preprocessing.*
import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

Expand All @@ -18,14 +15,15 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
*/
public open class ImageRecognitionModel(
internalModel: OnnxInferenceModel,
private val modelType: ModelType<out InferenceModel, out InferenceModel>,
private val channelsFirst: Boolean,
private val preprocessor: Operation<Pair<FloatArray, TensorShape>, Pair<FloatArray, TensorShape>> = Identity(),
override val classLabels: Map<Int, String> = Imagenet.V1k.labels()
) : ImageRecognitionModelBase<Bitmap>(internalModel), ExecutionProviderCompatible, CameraXCompatibleModel {
override var targetRotation: Float = 0f

override val preprocessing: Operation<Bitmap, Pair<FloatArray, TensorShape>>
get() {
val (width, height) = if (modelType.channelsFirst)
val (width, height) = if (channelsFirst)
Pair(internalModel.inputDimensions[1], internalModel.inputDimensions[2])
else
Pair(internalModel.inputDimensions[0], internalModel.inputDimensions[1])
Expand All @@ -36,8 +34,8 @@ public open class ImageRecognitionModel(
outputWidth = width.toInt()
}
.rotate { degrees = targetRotation }
.toFloatArray { layout = if (modelType.channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
.call(modelType.preprocessor)
.toFloatArray { layout = if (channelsFirst) TensorLayout.NCHW else TensorLayout.NHWC }
.call(preprocessor)
}

override fun initializeWith(vararg executionProviders: ExecutionProvider) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,23 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape

/** Models in the ONNX format and running via ONNX Runtime. */
public object ONNXModels {
/** Image recognition models and preprocessing. */
/** Image recognition models and preprocessing.
*
* @property [channelsFirst] If true, the second dimension holds the number of image channels
* (a layout abbreviated as `NCHW`);
* otherwise, channels are at the last position (a layout abbreviated as `NHWC`).
* @property [inputColorMode] The expected channel order of the input image.
* Note: a wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean,
override val inputColorMode: ColorMode = ColorMode.RGB,
protected val channelsFirst: Boolean,
private val inputColorMode: ColorMode = ColorMode.RGB,
/** If true, the model is shipped without its last few layers and can be used for transfer learning and fine-tuning with TF Runtime. */
internal var noTop: Boolean = false
) : OnnxModelType<T, ImageRecognitionModel> {
override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return ImageRecognitionModel(modelHub.loadModel(this), this)
return ImageRecognitionModel(modelHub.loadModel(this), inputColorMode, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -551,11 +558,7 @@ public object ONNXModels {
}

/** Object detection models and preprocessing. */
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class ObjectDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a real-time neural network for object detection that detects 80 different classes
Expand Down Expand Up @@ -825,11 +828,7 @@ public object ONNXModels {
}

/** Face alignment models and preprocessing. */
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class FaceAlignment<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a neural network for face alignment that takes RGB images of faces as input and produces the coordinates of 106 face landmarks.
Expand All @@ -850,11 +849,7 @@ public object ONNXModels {
}

/** Pose detection models. */
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = true,
override val inputColorMode: ColorMode = ColorMode.RGB
) :
public sealed class PoseDetection<T : InferenceModel, U : InferenceModel>(override val modelRelativePath: String) :
OnnxModelType<T, U> {
/**
* This model is a convolutional neural network model that runs on RGB images and predicts human joint locations of a single person.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@ import org.jetbrains.kotlinx.dl.dataset.shape.TensorShape
* All weights are imported from the `Keras.applications` or `ONNX.models` project and preprocessed with the KotlinDL project.
*/
public object TFModels {
/** Image recognition models and preprocessing. */
/** Image recognition models and preprocessing.
*
* @property [channelsFirst] If true, the second dimension holds the number of image channels
* (a layout abbreviated as `NCHW`);
* otherwise, channels are at the last position (a layout abbreviated as `NHWC`).
* @property [inputColorMode] The expected channel order of the input image.
* Note: a wrong choice of this parameter can significantly impact the model's performance.
* */
public sealed class CV<T : GraphTrainableModel>(
override val modelRelativePath: String,
override val channelsFirst: Boolean = false,
override val inputColorMode: ColorMode = ColorMode.RGB,
private val channelsFirst: Boolean = false,
private val inputColorMode: ColorMode = ColorMode.RGB,
public var inputShape: IntArray? = null,
internal var noTop: Boolean = false
) : ModelType<T, ImageRecognitionModel> {
Expand All @@ -41,7 +48,8 @@ public object TFModels {
}

override fun pretrainedModel(modelHub: ModelHub): ImageRecognitionModel {
return buildImageRecognitionModel(modelHub, this)
val model = loadModel(modelHub, this)
return ImageRecognitionModel(model, inputColorMode, channelsFirst, preprocessor)
}

/**
Expand Down Expand Up @@ -518,10 +526,7 @@ public object TFModels {
}
}

private fun buildImageRecognitionModel(
modelHub: ModelHub,
modelType: ModelType<out GraphTrainableModel, ImageRecognitionModel>
): ImageRecognitionModel {
private fun loadModel(modelHub: ModelHub, modelType: CV<out GraphTrainableModel>): GraphTrainableModel {
modelHub as TFModelHub
val model = modelHub.loadModel(modelType)
// TODO: this part is not needed for inference (if we could add manually Softmax at the end of the graph)
Expand All @@ -534,7 +539,6 @@ public object TFModels {
val hdfFile = modelHub.loadWeights(modelType)

model.loadWeights(hdfFile)

return ImageRecognitionModel(model, modelType)
return model
}
}