diff --git a/examples/src/test/kotlin/examples/onnx/ModelLoadingTestSuite.kt b/examples/src/test/kotlin/examples/onnx/ModelLoadingTestSuite.kt new file mode 100644 index 000000000..f8672fbec --- /dev/null +++ b/examples/src/test/kotlin/examples/onnx/ModelLoadingTestSuite.kt @@ -0,0 +1,22 @@ +package examples.onnx + +import examples.transferlearning.getFileFromResource +import org.jetbrains.kotlinx.dl.api.inference.onnx.OnnxInferenceModel +import org.jetbrains.kotlinx.dl.api.inference.onnx.executionproviders.ExecutionProvider.CPU +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow +import java.io.File + +class ModelLoadingTestSuite { + @Test + fun testLoadingModelFromBytes() { + val lgbmModel: File = getFileFromResource("models/onnx/lgbmSequenceOutput.onnx") + val bytes = lgbmModel.readBytes() + + assertDoesNotThrow { + val model = OnnxInferenceModel(bytes) + model.initializeWith(CPU()) + model.close() + } + } +} diff --git a/onnx/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt b/onnx/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt index 8e6a75c12..7323d1ef3 100644 --- a/onnx/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt +++ b/onnx/src/main/kotlin/org/jetbrains/kotlinx/dl/api/inference/onnx/OnnxInferenceModel.kt @@ -25,7 +25,15 @@ private const val RESHAPE_MISSED_MESSAGE = "Model input shape is not defined. Ca * * @since 0.3 */ -public open class OnnxInferenceModel(private val pathToModel: String) : InferenceModel() { +public open class OnnxInferenceModel private constructor() : InferenceModel() { + public constructor(modelPath: String) : this() { + this.pathToModel = modelPath + } + + public constructor(modelBytes: ByteArray) : this() { + this.modelBytes = modelBytes + } + /** Logger for the model. */ private val logger: KLogger = KotlinLogging.logger {} @@ -35,6 +43,16 @@ public open class OnnxInferenceModel(private val pathToModel: String) : Inferenc */ private val env = OrtEnvironment.getEnvironment() + /** + * Path to the model file. + */ + private lateinit var pathToModel: String + + /** + * Model represented as array of bytes. + */ + private lateinit var modelBytes: ByteArray + /** Wraps an ONNX model and allows inference calls. */ private lateinit var session: OrtSession @@ -93,7 +111,18 @@ public open class OnnxInferenceModel(private val pathToModel: String) : Inferenc session.close() } - session = env.createSession(pathToModel, buildSessionOptions(uniqueProviders)) + session = when { + ::modelBytes.isInitialized -> { + env.createSession(modelBytes, buildSessionOptions(uniqueProviders)) + } + ::pathToModel.isInitialized -> { + env.createSession(pathToModel, buildSessionOptions(uniqueProviders)) + } + else -> { + throw IllegalStateException("OnnxInferenceModel should be initialized either with a path to the file or with model representation in bytes.") + } + } + executionProvidersInUse = uniqueProviders initInputOutputInfo()