diff --git a/java/build.gradle b/java/build.gradle
index fa76ae927fe94..db4aa50d84627 100644
--- a/java/build.gradle
+++ b/java/build.gradle
@@ -167,6 +167,7 @@ test {
java {
dependsOn spotlessJava
}
+ forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool
useJUnitPlatform()
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
index fa199c9389c19..f71d26c8bfffa 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
@@ -333,7 +333,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, Object data) throws Or
*/
static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
TensorInfo info = TensorInfo.constructFromJavaArray(data);
if (info.type == OnnxJavaType.STRING) {
if (info.shape.length == 0) {
@@ -403,7 +403,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, String[] data, long[]
*/
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
TensorInfo info =
new TensorInfo(
shape,
@@ -451,7 +451,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, FloatBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.FLOAT;
return createTensor(type, allocator, data, shape);
} else {
@@ -492,7 +492,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, DoubleBuffer data, lon
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.DOUBLE;
return createTensor(type, allocator, data, shape);
} else {
@@ -571,7 +571,7 @@ public static OnnxTensor createTensor(
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
@@ -611,7 +611,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT16;
return createTensor(type, allocator, data, shape);
} else {
@@ -652,7 +652,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, IntBuffer data, long[]
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT32;
return createTensor(type, allocator, data, shape);
} else {
@@ -693,7 +693,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, LongBuffer data, long[
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape)
throws OrtException {
- if ((!env.isClosed()) && (!allocator.isClosed())) {
+ if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT64;
return createTensor(type, allocator, data, shape);
} else {
diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java
index dbf47ae027ead..b5f1fb21af28d 100644
--- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java
+++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
@@ -7,14 +7,18 @@
import ai.onnxruntime.OrtSession.SessionOptions;
import java.io.IOException;
import java.util.EnumSet;
-import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;
/**
* The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate
* specific models.
+ *
+ *
There can be at most one OrtEnvironment object created in a JVM lifetime. This class
+ * implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but
+ * the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered
+ * on construction.
*/
-public class OrtEnvironment implements AutoCloseable {
+public final class OrtEnvironment implements AutoCloseable {
private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());
@@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable {
private static volatile OrtEnvironment INSTANCE;
- private static final AtomicInteger refCount = new AtomicInteger();
-
private static volatile OrtLoggingLevel curLogLevel;
private static volatile String curLoggingName;
@@ -47,8 +49,6 @@ public static synchronized OrtEnvironment getEnvironment() {
// If there's no instance, create one.
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
} else {
- // else return the current one.
- refCount.incrementAndGet();
return INSTANCE;
}
}
@@ -106,7 +106,6 @@ public static synchronized OrtEnvironment getEnvironment(
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
}
}
- refCount.incrementAndGet();
return INSTANCE;
}
@@ -131,7 +130,6 @@ public static synchronized OrtEnvironment getEnvironment(
} catch (OrtException e) {
throw new IllegalStateException("Failed to create OrtEnvironment", e);
}
- refCount.incrementAndGet();
return INSTANCE;
} else {
// As the thread pool state is unknown, and that's probably not what the user wanted.
@@ -144,8 +142,6 @@ public static synchronized OrtEnvironment getEnvironment(
final OrtAllocator defaultAllocator;
- private volatile boolean closed = false;
-
/**
* Create an OrtEnvironment using a default name.
*
@@ -165,6 +161,8 @@ private OrtEnvironment() throws OrtException {
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
+ Runtime.getRuntime()
+ .addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}
/**
@@ -181,6 +179,8 @@ private OrtEnvironment(OrtLoggingLevel loggingLevel, String name, ThreadingOptio
createHandle(
OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
+ Runtime.getRuntime()
+ .addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}
/**
@@ -219,11 +219,7 @@ public OrtSession createSession(String modelPath, SessionOptions options) throws
*/
OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options)
throws OrtException {
- if (!closed) {
- return new OrtSession(this, modelPath, allocator, options);
- } else {
- throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
- }
+ return new OrtSession(this, modelPath, allocator, options);
}
/**
@@ -262,11 +258,7 @@ public OrtSession createSession(byte[] modelArray) throws OrtException {
*/
OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options)
throws OrtException {
- if (!closed) {
- return new OrtSession(this, modelArray, allocator, options);
- } else {
- throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
- }
+ return new OrtSession(this, modelArray, allocator, options);
}
/**
@@ -279,41 +271,11 @@ public void setTelemetry(boolean sendTelemetry) throws OrtException {
setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry);
}
- /**
- * Is this environment closed?
- *
- * @return True if the environment is closed.
- */
- public boolean isClosed() {
- return closed;
- }
-
@Override
public String toString() {
return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")";
}
- /**
- * Closes the OrtEnvironment. If this is the last reference to the environment then it closes the
- * native handle.
- *
- * @throws OrtException If the close failed.
- */
- @Override
- public synchronized void close() throws OrtException {
- synchronized (refCount) {
- int curCount = refCount.get();
- if (curCount != 0) {
- refCount.decrementAndGet();
- }
- if (curCount == 1) {
- close(OnnxRuntime.ortApiHandle, nativeHandle);
- closed = true;
- INSTANCE = null;
- }
- }
- }
-
/**
* Gets the providers available in this environment.
*
@@ -377,11 +339,22 @@ private static native long createHandle(
private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry)
throws OrtException;
+ /** Close is a no-op on OrtEnvironment since ORT 1.11. */
+ @Override
+ public void close() {}
+
/**
* Controls the global thread pools in the environment. Only used if the session is constructed
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
*/
public static final class ThreadingOptions implements AutoCloseable {
+ static {
+ try {
+ OnnxRuntime.init();
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to load onnx-runtime library", e);
+ }
+ }
private final long nativeHandle;
@@ -486,4 +459,24 @@ private native void setGlobalDenormalAsZero(long apiHandle, long nativeHandle)
private native void closeThreadingOptions(long apiHandle, long nativeHandle);
}
+
+ private static final class OrtEnvCloser implements Runnable {
+
+ private final long apiHandle;
+ private final long nativeHandle;
+
+ OrtEnvCloser(long apiHandle, long nativeHandle) {
+ this.apiHandle = apiHandle;
+ this.nativeHandle = nativeHandle;
+ }
+
+ @Override
+ public void run() {
+ try {
+ OrtEnvironment.close(apiHandle, nativeHandle);
+ } catch (OrtException e) {
+ System.err.println("Error closing OrtEnvironment, " + e);
+ }
+ }
+ }
}
diff --git a/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java b/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java
new file mode 100644
index 0000000000000..514976723840b
--- /dev/null
+++ b/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
+ * Licensed under the MIT License.
+ */
+package ai.onnxruntime;
+
+import static ai.onnxruntime.TestHelpers.getResourcePath;
+import static ai.onnxruntime.TestHelpers.loadTensorFromFile;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
+import org.junit.jupiter.api.Test;
+
+/** This test is in a separate class to ensure it is run in a clean JVM. */
+public class EnvironmentThreadPoolTest {
+
+ @Test
+ public void environmentThreadPoolTest() throws OrtException {
+ Path squeezeNet = getResourcePath("/squeezenet.onnx");
+ String modelPath = squeezeNet.toString();
+ float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
+ float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
+ Map container = new HashMap<>();
+
+ OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
+ threadOpts.setGlobalInterOpNumThreads(2);
+ threadOpts.setGlobalIntraOpNumThreads(2);
+ threadOpts.setGlobalDenormalAsZero();
+ threadOpts.setGlobalSpinControl(true);
+ OrtEnvironment env =
+ OrtEnvironment.getEnvironment(
+ OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
+ try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
+ OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) {
+ disableThreadOptions.disablePerSessionThreads();
+
+ // Check that the regular session executes
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
+ long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
+ Object tensorData = OrtUtil.reshape(inputData, inputShape);
+ OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
+ container.put(inputMeta.getName(), tensor);
+ try (OrtSession.Result result = session.run(container)) {
+ OnnxValue resultTensor = result.get(0);
+ float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
+ assertEquals(expectedOutput.length, resultArray.length);
+ assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+ }
+ container.clear();
+ tensor.close();
+ }
+
+ // Check that the session using the env thread pool executes
+ try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
+ NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
+ long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
+ Object tensorData = OrtUtil.reshape(inputData, inputShape);
+ OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
+ container.put(inputMeta.getName(), tensor);
+ try (OrtSession.Result result = session.run(container)) {
+ OnnxValue resultTensor = result.get(0);
+ float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
+ assertEquals(expectedOutput.length, resultArray.length);
+ assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+ }
+ container.clear();
+ tensor.close();
+ }
+ }
+
+ try {
+ OrtEnvironment newEnv =
+ OrtEnvironment.getEnvironment(
+ OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
+ // fail as we can't recreate environments with different threading options
+ fail("Should have thrown IllegalStateException");
+ } catch (IllegalStateException e) {
+ // pass
+ }
+
+ threadOpts.close();
+ }
+}
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index b73583cab7441..6a6bfa0a7c6a3 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -1,30 +1,22 @@
/*
- * Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
-import ai.onnxruntime.OnnxMl.TensorProto;
-import ai.onnxruntime.OnnxMl.TensorProto.DataType;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode;
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
-import java.io.BufferedInputStream;
-import java.io.BufferedReader;
import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileReader;
import java.io.IOException;
-import java.io.InputStream;
-import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
@@ -44,7 +36,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
-import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
@@ -57,76 +48,65 @@
/** Tests for the onnx-runtime Java interface. */
public class InferenceTest {
private static final Logger logger = Logger.getLogger(InferenceTest.class.getName());
- private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]");
private static final String propertiesFile = "Properties.txt";
private static final Pattern inputPBPattern = Pattern.compile("input_*.pb");
private static final Pattern outputPBPattern = Pattern.compile("output_*.pb");
- private static Path getResourcePath(String path) {
- return new File(InferenceTest.class.getResource(path).getFile()).toPath();
- }
+ private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
@Test
- public void repeatedCloseTest() throws OrtException {
- Logger.getLogger(OrtEnvironment.class.getName()).setLevel(Level.SEVERE);
- OrtEnvironment env = OrtEnvironment.getEnvironment("repeatedCloseTest");
- try (OrtEnvironment otherEnv = OrtEnvironment.getEnvironment()) {
- assertFalse(otherEnv.isClosed());
- }
- assertFalse(env.isClosed());
- env.close();
- assertTrue(env.isClosed());
+ public void environmentTest() {
+ // Checks that the environment instance is the same.
+ OrtEnvironment otherEnv = OrtEnvironment.getEnvironment();
+ assertSame(env, otherEnv);
+ otherEnv = OrtEnvironment.getEnvironment("test-name");
+ assertSame(env, otherEnv);
}
@Test
public void createSessionFromPath() throws OrtException {
- String modelPath = getResourcePath("/squeezenet.onnx").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath");
- OrtSession.SessionOptions options = new SessionOptions()) {
- try (OrtSession session = env.createSession(modelPath, options)) {
- assertNotNull(session);
- assertEquals(1, session.getNumInputs()); // 1 input node
- Map inputInfoList = session.getInputInfo();
- assertNotNull(inputInfoList);
- assertEquals(1, inputInfoList.size());
- NodeInfo input = inputInfoList.get("data_0");
- assertEquals("data_0", input.getName()); // input node name
- assertTrue(input.getInfo() instanceof TensorInfo);
- TensorInfo inputInfo = (TensorInfo) input.getInfo();
- assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
- int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
- assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
- for (int i = 0; i < expectedInputDimensions.length; i++) {
- assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
- }
+ String modelPath = TestHelpers.getResourcePath("/squeezenet.onnx").toString();
+ try (OrtSession.SessionOptions options = new SessionOptions();
+ OrtSession session = env.createSession(modelPath, options)) {
+ assertNotNull(session);
+ assertEquals(1, session.getNumInputs()); // 1 input node
+ Map inputInfoList = session.getInputInfo();
+ assertNotNull(inputInfoList);
+ assertEquals(1, inputInfoList.size());
+ NodeInfo input = inputInfoList.get("data_0");
+ assertEquals("data_0", input.getName()); // input node name
+ assertTrue(input.getInfo() instanceof TensorInfo);
+ TensorInfo inputInfo = (TensorInfo) input.getInfo();
+ assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
+ int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
+ assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
+ for (int i = 0; i < expectedInputDimensions.length; i++) {
+ assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
+ }
- assertEquals(1, session.getNumOutputs()); // 1 output node
- Map outputInfoList = session.getOutputInfo();
- assertNotNull(outputInfoList);
- assertEquals(1, outputInfoList.size());
- NodeInfo output = outputInfoList.get("softmaxout_1");
- assertEquals("softmaxout_1", output.getName()); // output node name
- assertTrue(output.getInfo() instanceof TensorInfo);
- TensorInfo outputInfo = (TensorInfo) output.getInfo();
- assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
- int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1};
- assertEquals(expectedOutputDimensions.length, outputInfo.shape.length);
- for (int i = 0; i < expectedOutputDimensions.length; i++) {
- assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]);
- }
+ assertEquals(1, session.getNumOutputs()); // 1 output node
+ Map outputInfoList = session.getOutputInfo();
+ assertNotNull(outputInfoList);
+ assertEquals(1, outputInfoList.size());
+ NodeInfo output = outputInfoList.get("softmaxout_1");
+ assertEquals("softmaxout_1", output.getName()); // output node name
+ assertTrue(output.getInfo() instanceof TensorInfo);
+ TensorInfo outputInfo = (TensorInfo) output.getInfo();
+ assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
+ int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1};
+ assertEquals(expectedOutputDimensions.length, outputInfo.shape.length);
+ for (int i = 0; i < expectedOutputDimensions.length; i++) {
+ assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]);
}
}
}
@Test
public void morePartialInputsTest() throws OrtException {
- String modelPath = getResourcePath("/partial-inputs-test-2.onnx").toString();
- try (OrtEnvironment env =
- OrtEnvironment.getEnvironment(
- OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
- OrtSession.SessionOptions options = new SessionOptions();
+ String modelPath = TestHelpers.getResourcePath("/partial-inputs-test-2.onnx").toString();
+ try (OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
assertNotNull(session);
assertEquals(3, session.getNumInputs());
@@ -209,11 +189,8 @@ public void morePartialInputsTest() throws OrtException {
@Test
public void partialInputsTest() throws OrtException {
- String modelPath = getResourcePath("/partial-inputs-test.onnx").toString();
- try (OrtEnvironment env =
- OrtEnvironment.getEnvironment(
- OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs");
- OrtSession.SessionOptions options = new SessionOptions();
+ String modelPath = TestHelpers.getResourcePath("/partial-inputs-test.onnx").toString();
+ try (OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
assertNotNull(session);
assertEquals(3, session.getNumInputs());
@@ -347,122 +324,50 @@ public void partialInputsTest() throws OrtException {
@Test
public void createSessionFromByteArray() throws IOException, OrtException {
- Path modelPath = getResourcePath("/squeezenet.onnx");
+ Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
byte[] modelBytes = Files.readAllBytes(modelPath);
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromByteArray");
- OrtSession.SessionOptions options = new SessionOptions()) {
- try (OrtSession session = env.createSession(modelBytes, options)) {
- assertNotNull(session);
- assertEquals(1, session.getNumInputs()); // 1 input node
- Map inputInfoList = session.getInputInfo();
- assertNotNull(inputInfoList);
- assertEquals(1, inputInfoList.size());
- NodeInfo input = inputInfoList.get("data_0");
- assertEquals("data_0", input.getName()); // input node name
- assertTrue(input.getInfo() instanceof TensorInfo);
- TensorInfo inputInfo = (TensorInfo) input.getInfo();
- assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
- int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
- assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
- for (int i = 0; i < expectedInputDimensions.length; i++) {
- assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
- }
-
- assertEquals(1, session.getNumOutputs()); // 1 output node
- Map outputInfoList = session.getOutputInfo();
- assertNotNull(outputInfoList);
- assertEquals(1, outputInfoList.size());
- NodeInfo output = outputInfoList.get("softmaxout_1");
- assertEquals("softmaxout_1", output.getName()); // output node name
- assertTrue(output.getInfo() instanceof TensorInfo);
- TensorInfo outputInfo = (TensorInfo) output.getInfo();
- assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
- int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1};
- assertEquals(expectedOutputDimensions.length, outputInfo.shape.length);
- for (int i = 0; i < expectedOutputDimensions.length; i++) {
- assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]);
- }
-
- // Check the metadata can be extracted
- OnnxModelMetadata metadata = session.getMetadata();
- assertEquals("onnx-caffe2", metadata.getProducerName());
- assertEquals("squeezenet_old", metadata.getGraphName());
- assertEquals("", metadata.getDomain());
- assertEquals("", metadata.getDescription());
- assertEquals(0x7FFFFFFFFFFFFFFFL, metadata.getVersion());
- assertTrue(metadata.getCustomMetadata().isEmpty());
- }
- }
- }
-
- @Test
- public void environmentThreadPoolTest() throws OrtException {
- Path squeezeNet = getResourcePath("/squeezenet.onnx");
- String modelPath = squeezeNet.toString();
- float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
- float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
- Map container = new HashMap<>();
-
- OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
- threadOpts.setGlobalInterOpNumThreads(2);
- threadOpts.setGlobalIntraOpNumThreads(2);
- threadOpts.setGlobalDenormalAsZero();
- threadOpts.setGlobalSpinControl(true);
- try (OrtEnvironment env =
- OrtEnvironment.getEnvironment(
- OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
- OrtSession.SessionOptions options = new SessionOptions();
- OrtSession.SessionOptions disableThreadOptions = new SessionOptions()) {
- disableThreadOptions.disablePerSessionThreads();
-
- // Check that the regular session executes
- try (OrtSession session = env.createSession(modelPath, options)) {
- NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
- long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
- Object tensorData = OrtUtil.reshape(inputData, inputShape);
- OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
- container.put(inputMeta.getName(), tensor);
- try (OrtSession.Result result = session.run(container)) {
- OnnxValue resultTensor = result.get(0);
- float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
- assertEquals(expectedOutput.length, resultArray.length);
- assertArrayEquals(expectedOutput, resultArray, 1e-6f);
- }
- container.clear();
- tensor.close();
+ try (OrtSession.SessionOptions options = new SessionOptions();
+ OrtSession session = env.createSession(modelBytes, options)) {
+ assertNotNull(session);
+ assertEquals(1, session.getNumInputs()); // 1 input node
+ Map inputInfoList = session.getInputInfo();
+ assertNotNull(inputInfoList);
+ assertEquals(1, inputInfoList.size());
+ NodeInfo input = inputInfoList.get("data_0");
+ assertEquals("data_0", input.getName()); // input node name
+ assertTrue(input.getInfo() instanceof TensorInfo);
+ TensorInfo inputInfo = (TensorInfo) input.getInfo();
+ assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
+ int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
+ assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
+ for (int i = 0; i < expectedInputDimensions.length; i++) {
+ assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
}
- // Check that the session using the env thread pool executes
- try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
- NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
- long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
- Object tensorData = OrtUtil.reshape(inputData, inputShape);
- OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
- container.put(inputMeta.getName(), tensor);
- try (OrtSession.Result result = session.run(container)) {
- OnnxValue resultTensor = result.get(0);
- float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
- assertEquals(expectedOutput.length, resultArray.length);
- assertArrayEquals(expectedOutput, resultArray, 1e-6f);
- }
- container.clear();
- tensor.close();
+ assertEquals(1, session.getNumOutputs()); // 1 output node
+ Map outputInfoList = session.getOutputInfo();
+ assertNotNull(outputInfoList);
+ assertEquals(1, outputInfoList.size());
+ NodeInfo output = outputInfoList.get("softmaxout_1");
+ assertEquals("softmaxout_1", output.getName()); // output node name
+ assertTrue(output.getInfo() instanceof TensorInfo);
+ TensorInfo outputInfo = (TensorInfo) output.getInfo();
+ assertEquals(OnnxJavaType.FLOAT, outputInfo.type);
+ int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1};
+ assertEquals(expectedOutputDimensions.length, outputInfo.shape.length);
+ for (int i = 0; i < expectedOutputDimensions.length; i++) {
+ assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]);
}
- }
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("test")) {
- try {
- OrtEnvironment newEnv =
- OrtEnvironment.getEnvironment(
- OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
- // fail as we can't recreate environments with different threading options
- fail("Should have thrown IllegalStateException");
- } catch (IllegalStateException e) {
- // pass
- }
+ // Check the metadata can be extracted
+ OnnxModelMetadata metadata = session.getMetadata();
+ assertEquals("onnx-caffe2", metadata.getProducerName());
+ assertEquals("squeezenet_old", metadata.getGraphName());
+ assertEquals("", metadata.getDomain());
+ assertEquals("", metadata.getDescription());
+ assertEquals(0x7FFFFFFFFFFFFFFFL, metadata.getVersion());
+ assertTrue(metadata.getCustomMetadata().isEmpty());
}
-
- threadOpts.close();
}
@Test
@@ -475,11 +380,10 @@ public void inferenceTest() throws OrtException {
private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionMode exectionMode)
throws OrtException {
- String modelPath = getResourcePath("/squeezenet.onnx").toString();
+ String modelPath = TestHelpers.getResourcePath("/squeezenet.onnx").toString();
// Set the graph optimization level for this session.
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("canRunInferenceOnAModel");
- SessionOptions options = new SessionOptions()) {
+ try (SessionOptions options = new SessionOptions()) {
options.setOptimizationLevel(graphOptimizationLevel);
options.setExecutionMode(exectionMode);
@@ -488,7 +392,8 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM
Map container = new HashMap<>();
NodeInfo inputMeta = inputMetaMap.values().iterator().next();
- float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
+ float[] inputData =
+ TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.in"));
// this is the data for only one input tensor for this model
Object tensorData =
OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape());
@@ -499,7 +404,8 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM
try (OrtSession.Result results = session.run(container)) {
assertEquals(1, results.size());
- float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
+ float[] expectedOutput =
+ TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.expected_out"));
// validate the results
// Only iterates once
for (Map.Entry r : results) {
@@ -529,8 +435,7 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM
@Test
public void throwWrongInputName() throws OrtException {
SqueezeNetTuple tuple = openSessionSqueezeNet();
- try (OrtEnvironment env = tuple.env;
- OrtSession session = tuple.session) {
+ try (OrtSession session = tuple.session) {
float[] inputData = tuple.inputData;
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
Map container = new HashMap<>();
@@ -556,8 +461,7 @@ public void throwWrongInputName() throws OrtException {
@Test
public void throwWrongInputType() throws OrtException {
SqueezeNetTuple tuple = openSessionSqueezeNet();
- try (OrtEnvironment env = tuple.env;
- OrtSession session = tuple.session) {
+ try (OrtSession session = tuple.session) {
float[] inputData = tuple.inputData;
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
@@ -584,8 +488,7 @@ public void throwWrongInputType() throws OrtException {
@Test
public void throwExtraInputs() throws OrtException {
SqueezeNetTuple tuple = openSessionSqueezeNet();
- try (OrtEnvironment env = tuple.env;
- OrtSession session = tuple.session) {
+ try (OrtSession session = tuple.session) {
float[] inputData = tuple.inputData;
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
@@ -615,8 +518,7 @@ public void testMultiThreads() throws OrtException, InterruptedException {
int numThreads = 10;
int loop = 10;
SqueezeNetTuple tuple = openSessionSqueezeNet();
- try (OrtEnvironment env = tuple.env;
- OrtSession session = tuple.session) {
+ try (OrtSession session = tuple.session) {
float[] inputData = tuple.inputData;
float[] expectedOutput = tuple.outputData;
@@ -664,25 +566,23 @@ public void testProviders() {
@Test
public void testSymbolicDimensionAssignment() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/capi_symbolic_dims.onnx").toString();
+ String modelPath = TestHelpers.getResourcePath("/capi_symbolic_dims.onnx").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testSymbolicDimensionAssignment")) {
- // Check the dimension is symbolic
- try (SessionOptions options = new SessionOptions()) {
- try (OrtSession session = env.createSession(modelPath, options)) {
- Map infoMap = session.getInputInfo();
- TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
- assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
- }
+ // Check the dimension is symbolic
+ try (SessionOptions options = new SessionOptions()) {
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ Map infoMap = session.getInputInfo();
+ TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
+ assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
}
- // Check that when the options are assigned it overrides the symbolic dimension
- try (SessionOptions options = new SessionOptions()) {
- options.setSymbolicDimensionValue("n", 5);
- try (OrtSession session = env.createSession(modelPath, options)) {
- Map infoMap = session.getInputInfo();
- TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
- assertArrayEquals(new long[] {5, 2}, aInfo.shape);
- }
+ }
+ // Check that when the options are assigned it overrides the symbolic dimension
+ try (SessionOptions options = new SessionOptions()) {
+ options.setSymbolicDimensionValue("n", 5);
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ Map infoMap = session.getInputInfo();
+ TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
+ assertArrayEquals(new long[] {5, 2}, aInfo.shape);
}
}
}
@@ -723,8 +623,7 @@ private void runProvider(OrtProvider provider) throws OrtException {
assertTrue(providers.contains(OrtProvider.CPU));
assertTrue(providers.contains(provider));
SqueezeNetTuple tuple = openSessionSqueezeNet(EnumSet.of(provider));
- try (OrtEnvironment env = tuple.env;
- OrtSession session = tuple.session) {
+ try (OrtSession session = tuple.session) {
float[] inputData = tuple.inputData;
float[] expectedOutput = tuple.outputData;
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
@@ -853,8 +752,7 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti
+ modelNamesList);
}
- try (OrtEnvironment env = OrtEnvironment.getEnvironment();
- OrtSession session = env.createSession(onnxModelFileName)) {
+ try (OrtSession session = env.createSession(onnxModelFileName)) {
String testDataDirNamePattern;
if (opset.equals("opset9") && modelName.equals("LSTM_Seq_lens_unpacked")) {
testDataDirNamePattern = "seq_lens"; // discrepency in data directory
@@ -869,12 +767,12 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti
Map outputContainer = new HashMap<>();
for (File f :
testDataDir.listFiles((dir, name) -> inputPBPattern.matcher(name).matches())) {
- StringTensorPair o = loadTensorFromFilePb(env, f, inMeta);
+ TestHelpers.StringTensorPair o = TestHelpers.loadTensorFromFilePb(env, f, inMeta);
inputContainer.put(o.string, o.tensor);
}
for (File f :
testDataDir.listFiles((dir, name) -> outputPBPattern.matcher(name).matches())) {
- StringTensorPair o = loadTensorFromFilePb(env, f, outMeta);
+ TestHelpers.StringTensorPair o = TestHelpers.loadTensorFromFilePb(env, f, outMeta);
outputContainer.put(o.string, o.tensor);
}
@@ -915,10 +813,9 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti
@Test
public void testModelInputFLOAT() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_FLOAT.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
long[] shape = new long[] {1, 5};
@@ -942,10 +839,9 @@ public void testModelInputFLOAT() throws OrtException {
@Test
public void testModelInputBuffer() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_FLOAT.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_FLOAT.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
long[] shape = new long[] {1, 5};
@@ -1003,10 +899,9 @@ public void testModelInputBuffer() throws OrtException {
@Test
public void testRunOptions() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testRunOptions");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options);
OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) {
runOptions.setRunTag("monkeys");
@@ -1035,73 +930,71 @@ public void testRunOptions() throws OrtException {
@Test
public void testExtraSessionOptions() throws OrtException, IOException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString();
File tmpPath = File.createTempFile("onnx-runtime-profiling", "file");
tmpPath.deleteOnExit();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testExtraSessionOptions")) {
- try (SessionOptions options = new SessionOptions()) {
- options.setCPUArenaAllocator(true);
- options.setMemoryPatternOptimization(true);
- options.enableProfiling(tmpPath.getAbsolutePath());
- options.setLoggerId("monkeys");
- options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
- options.setSessionLogVerbosityLevel(5);
- Map configEntries = options.getConfigEntries();
- assertTrue(configEntries.isEmpty());
- options.addConfigEntry("key", "value");
- assertEquals("value", configEntries.get("key"));
- try {
- options.addConfigEntry("", "invalid key");
- fail("Add config entry with empty key should have failed");
- } catch (OrtException e) {
- assertTrue(e.getMessage().contains("Config key is empty"));
- assertEquals(OrtException.OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+ try (SessionOptions options = new SessionOptions()) {
+ options.setCPUArenaAllocator(true);
+ options.setMemoryPatternOptimization(true);
+ options.enableProfiling(tmpPath.getAbsolutePath());
+ options.setLoggerId("monkeys");
+ options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
+ options.setSessionLogVerbosityLevel(5);
+ Map configEntries = options.getConfigEntries();
+ assertTrue(configEntries.isEmpty());
+ options.addConfigEntry("key", "value");
+ assertEquals("value", configEntries.get("key"));
+ try {
+ options.addConfigEntry("", "invalid key");
+ fail("Add config entry with empty key should have failed");
+ } catch (OrtException e) {
+ assertTrue(e.getMessage().contains("Config key is empty"));
+ assertEquals(OrtException.OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode());
+ }
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ String inputName = session.getInputNames().iterator().next();
+ Map container = new HashMap<>();
+ boolean[] flatInput = new boolean[] {true, false, true, false, true};
+ Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
+ OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
+ container.put(inputName, ov);
+ try (OrtSession.Result res = session.run(container)) {
+ boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
+ assertArrayEquals(flatInput, resultArray);
}
- try (OrtSession session = env.createSession(modelPath, options)) {
- String inputName = session.getInputNames().iterator().next();
- Map container = new HashMap<>();
- boolean[] flatInput = new boolean[] {true, false, true, false, true};
- Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
- OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
- container.put(inputName, ov);
- try (OrtSession.Result res = session.run(container)) {
- boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
- assertArrayEquals(flatInput, resultArray);
- }
- // Check that the profiling start time doesn't throw
- long profilingStartTime = session.getProfilingStartTimeInNs();
+ // Check that the profiling start time doesn't throw
+ long profilingStartTime = session.getProfilingStartTimeInNs();
- // Check the profiling output doesn't throw
- String profilingOutput = session.endProfiling();
- File profilingOutputFile = new File(profilingOutput);
- profilingOutputFile.deleteOnExit();
- try (OrtSession.Result res = session.run(container)) {
- boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
- assertArrayEquals(flatInput, resultArray);
- }
- OnnxValue.close(container);
+ // Check the profiling output doesn't throw
+ String profilingOutput = session.endProfiling();
+ File profilingOutputFile = new File(profilingOutput);
+ profilingOutputFile.deleteOnExit();
+ try (OrtSession.Result res = session.run(container)) {
+ boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
+ assertArrayEquals(flatInput, resultArray);
}
+ OnnxValue.close(container);
}
- try (SessionOptions options = new SessionOptions()) {
- options.setCPUArenaAllocator(false);
- options.setMemoryPatternOptimization(false);
- options.enableProfiling(tmpPath.getAbsolutePath());
- options.disableProfiling();
- options.setSessionLogVerbosityLevel(0);
- try (OrtSession session = env.createSession(modelPath, options)) {
- String inputName = session.getInputNames().iterator().next();
- Map container = new HashMap<>();
- boolean[] flatInput = new boolean[] {true, false, true, false, true};
- Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
- OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
- container.put(inputName, ov);
- try (OrtSession.Result res = session.run(container)) {
- boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
- assertArrayEquals(flatInput, resultArray);
- }
- OnnxValue.close(container);
+ }
+ try (SessionOptions options = new SessionOptions()) {
+ options.setCPUArenaAllocator(false);
+ options.setMemoryPatternOptimization(false);
+ options.enableProfiling(tmpPath.getAbsolutePath());
+ options.disableProfiling();
+ options.setSessionLogVerbosityLevel(0);
+ try (OrtSession session = env.createSession(modelPath, options)) {
+ String inputName = session.getInputNames().iterator().next();
+ Map container = new HashMap<>();
+ boolean[] flatInput = new boolean[] {true, false, true, false, true};
+ Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5});
+ OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn);
+ container.put(inputName, ov);
+ try (OrtSession.Result res = session.run(container)) {
+ boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue());
+ assertArrayEquals(flatInput, resultArray);
}
+ OnnxValue.close(container);
}
}
}
@@ -1115,19 +1008,18 @@ public void testLoadCustomLibrary() throws OrtException {
if (osName.contains("windows")) {
// In windows we start in the wrong working directory relative to the custom_op_library.dll
// So we look it up as a classpath resource and resolve it to a real path
- customLibraryName = getResourcePath("/custom_op_library.dll").toString();
+ customLibraryName = TestHelpers.getResourcePath("/custom_op_library.dll").toString();
} else if (osName.contains("mac")) {
- customLibraryName = getResourcePath("/libcustom_op_library.dylib").toString();
+ customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.dylib").toString();
} else if (osName.contains("linux")) {
- customLibraryName = getResourcePath("/libcustom_op_library.so").toString();
+ customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.so").toString();
} else {
fail("Unknown os/platform '" + osName + "'");
}
String customOpLibraryTestModel =
- getResourcePath("/custom_op_library/custom_op_test.onnx").toString();
+ TestHelpers.getResourcePath("/custom_op_library/custom_op_test.onnx").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary");
- SessionOptions options = new SessionOptions()) {
+ try (SessionOptions options = new SessionOptions()) {
options.registerCustomOpLibrary(customLibraryName);
if (OnnxRuntime.extractCUDA()) {
options.addCUDA();
@@ -1170,41 +1062,39 @@ public void testLoadCustomLibrary() throws OrtException {
@Test
public void testModelMetadata() throws OrtException {
- String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString();
+ String modelPath =
+ TestHelpers.getResourcePath("/model_with_valid_ort_config_json.onnx").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) {
- try (OrtSession session = env.createSession(modelPath)) {
- OnnxModelMetadata modelMetadata = session.getMetadata();
+ try (OrtSession session = env.createSession(modelPath)) {
+ OnnxModelMetadata modelMetadata = session.getMetadata();
- Assertions.assertEquals(1, modelMetadata.getVersion());
+ Assertions.assertEquals(1, modelMetadata.getVersion());
- Assertions.assertEquals("Hari", modelMetadata.getProducerName());
+ Assertions.assertEquals("Hari", modelMetadata.getProducerName());
- Assertions.assertEquals("matmul test", modelMetadata.getGraphName());
+ Assertions.assertEquals("matmul test", modelMetadata.getGraphName());
- Assertions.assertEquals("", modelMetadata.getDomain());
+ Assertions.assertEquals("", modelMetadata.getDomain());
- Assertions.assertEquals(
- "This is a test model with a valid ORT config Json", modelMetadata.getDescription());
+ Assertions.assertEquals(
+ "This is a test model with a valid ORT config Json", modelMetadata.getDescription());
- Assertions.assertEquals("graph description", modelMetadata.getGraphDescription());
+ Assertions.assertEquals("graph description", modelMetadata.getGraphDescription());
- Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size());
- Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key"));
- Assertions.assertEquals(
- "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
- modelMetadata.getCustomMetadata().get("ort_config"));
- }
+ Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size());
+ Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key"));
+ Assertions.assertEquals(
+ "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
+ modelMetadata.getCustomMetadata().get("ort_config"));
}
}
@Test
public void testModelInputBOOL() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_BOOL.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputBOOL");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1223,10 +1113,9 @@ public void testModelInputBOOL() throws OrtException {
@Test
public void testModelInputINT32() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_INT32.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_INT32.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT32");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1245,10 +1134,9 @@ public void testModelInputINT32() throws OrtException {
@Test
public void testModelInputDOUBLE() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_DOUBLE.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_DOUBLE.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputDOUBLE");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1267,10 +1155,9 @@ public void testModelInputDOUBLE() throws OrtException {
@Test
public void testModelInputINT8() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_INT8.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_INT8.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT8");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1288,10 +1175,9 @@ public void testModelInputINT8() throws OrtException {
@Test
public void testModelInputUINT8() throws OrtException {
- String modelPath = getResourcePath("/test_types_UINT8.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_UINT8.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1311,10 +1197,9 @@ public void testModelInputUINT8() throws OrtException {
@Test
public void testModelInputINT16() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_INT16.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_INT16.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT16");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1333,10 +1218,9 @@ public void testModelInputINT16() throws OrtException {
@Test
public void testModelInputINT64() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
- String modelPath = getResourcePath("/test_types_INT64.pb").toString();
+ String modelPath = TestHelpers.getResourcePath("/test_types_INT64.pb").toString();
- try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT64");
- SessionOptions options = new SessionOptions();
+ try (SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelPath, options)) {
String inputName = session.getInputNames().iterator().next();
Map container = new HashMap<>();
@@ -1360,9 +1244,8 @@ public void testModelSequenceOfMapIntFloat() throws OrtException {
// "probabilities" is a sequence