Add API to create OnnxInferenceModel from a byte array representing an ONNX model (Kotlin#415)
ermolenkodev committed Aug 8, 2022
1 parent 4c934df commit c17a650
Showing 2 changed files with 53 additions and 2 deletions.
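
The new constructor is useful when the ONNX model is not available as a file on disk, for example when its bytes are bundled as a classpath resource or received over the network. Below is a minimal usage sketch, not part of the commit; the resource name and helper function are illustrative, while OnnxInferenceModel(ByteArray), initializeWith and ExecutionProvider.CPU come from the changes shown further down.

    import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
    import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU

    // Illustrative helper: build an OnnxInferenceModel from raw bytes without touching the filesystem.
    fun loadBundledModel(): OnnxInferenceModel {
        // Hypothetical resource path; any source of raw ONNX bytes works here.
        val bytes = object {}.javaClass
            .getResourceAsStream("/models/onnx/lgbmSequenceOutput.onnx")!!
            .readBytes()
        val model = OnnxInferenceModel(bytes) // new byte-array constructor added by this commit
        model.initializeWith(CPU())           // creates the ONNX Runtime session on the CPU provider
        return model
    }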
22 changes: 22 additions & 0 deletions examples/src/test/kotlin/examples/onnx/ModelLoadingTestSuite.kt
@@ -0,0 +1,22 @@
+package examples.onnx
+
+import examples.transferlearning.getFileFromResource
+import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel
+import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.assertDoesNotThrow
+import java.io.File
+
+class ModelLoadingTestSuite {
+    @Test
+    fun testLoadingModelFromBytes() {
+        val lgbmModel: File = getFileFromResource("models/onnx/lgbmSequenceOutput.onnx")
+        val bytes = lgbmModel.readBytes()
+
+        assertDoesNotThrow {
+            val model = OnnxInferenceModel(bytes)
+            model.initializeWith(CPU())
+            model.close()
+        }
+    }
+}
31 changes: 31 additions & 2 deletions OnnxInferenceModel.kt
@@ -25,7 +25,15 @@ private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Ca
  *
  * @since 0.3
  */
-public open class OnnxInferenceModel(private val pathToModel: String) : InferenceModel() {
+public open class OnnxInferenceModel private constructor() : InferenceModel() {
+    public constructor(modelPath: String) : this() {
+        this.pathToModel = modelPath
+    }
+
+    public constructor(modelBytes: ByteArray) : this() {
+        this.modelBytes = modelBytes
+    }
+
     /** Logger for the model. */
     private val logger: KLogger = KotlinLogging.logger {}
 
@@ -35,6 +43,16 @@ public open class OnnxInferenceModel(private val pathToModel: String) : Inferenc
      */
     private val env = OrtEnvironment.getEnvironment()
 
+    /**
+     * Path to the model file.
+     */
+    private lateinit var pathToModel: String
+
+    /**
+     * Model represented as array of bytes.
+     */
+    private lateinit var modelBytes: ByteArray
+
     /** Wraps an ONNX model and allows inference calls. */
     private lateinit var session: OrtSession
 
@@ -93,7 +111,18 @@ public open class OnnxInferenceModel(private val pathToModel: String) : Inferenc
             session.close()
         }
 
-        session = env.createSession(pathToModel, buildSessionOptions(uniqueProviders))
+        session = when {
+            ::modelBytes.isInitialized -> {
+                env.createSession(modelBytes, buildSessionOptions(uniqueProviders))
+            }
+            ::pathToModel.isInitialized -> {
+                env.createSession(pathToModel, buildSessionOptions(uniqueProviders))
+            }
+            else -> {
+                throw IllegalStateException("OnnxInferenceModel should be initialized either with a path to the file or with model representation in bytes.")
+            }
+        }
+
         executionProvidersInUse = uniqueProviders
 
         initInputOutputInfo()
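
For reference, the when block above selects between two overloads of the ONNX Runtime Java API depending on which field the chosen constructor initialized. A standalone sketch of the same dispatch, with an illustrative helper name; the createSession overloads are part of ai.onnxruntime.

    import ai.onnxruntime.OrtEnvironment
    import ai.onnxruntime.OrtSession

    // Illustrative helper mirroring the dispatch above: ONNX Runtime can build a session
    // either from in-memory model bytes or from a path to a model file.
    fun openSession(path: String?, bytes: ByteArray?, options: OrtSession.SessionOptions): OrtSession {
        val env = OrtEnvironment.getEnvironment()
        return when {
            bytes != null -> env.createSession(bytes, options) // byte[] overload
            path != null -> env.createSession(path, options)   // file-path overload
            else -> throw IllegalStateException("Either a model path or model bytes must be provided.")
        }
    }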
