diff --git a/java/build.gradle b/java/build.gradle index fa76ae927fe94..db4aa50d84627 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -167,6 +167,7 @@ test { java { dependsOn spotlessJava } + forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool useJUnitPlatform() if (cmakeBuildDir != null) { workingDir cmakeBuildDir diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index fa199c9389c19..f71d26c8bfffa 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -333,7 +333,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, Object data) throws Or */ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { TensorInfo info = TensorInfo.constructFromJavaArray(data); if (info.type == OnnxJavaType.STRING) { if (info.shape.length == 0) { @@ -403,7 +403,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, String[] data, long[] */ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { TensorInfo info = new TensorInfo( shape, @@ -451,7 +451,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, FloatBuffer data, long static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.FLOAT; return createTensor(type, allocator, data, shape); } else { @@ -492,7 +492,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, DoubleBuffer data, lon static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.DOUBLE; return createTensor(type, allocator, data, shape); } else { @@ -571,7 +571,7 @@ public static OnnxTensor createTensor( static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { return createTensor(type, allocator, data, shape); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); @@ -611,7 +611,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT16; return createTensor(type, allocator, data, shape); } else { @@ -652,7 +652,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, IntBuffer data, long[] static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT32; return createTensor(type, allocator, data, shape); } else { @@ -693,7 +693,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, LongBuffer data, long[ static OnnxTensor createTensor( OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape) throws OrtException { - if ((!env.isClosed()) && (!allocator.isClosed())) { + if (!allocator.isClosed()) { OnnxJavaType type = OnnxJavaType.INT64; return createTensor(type, allocator, data, shape); } else { diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index dbf47ae027ead..b5f1fb21af28d 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,14 +7,18 @@ import ai.onnxruntime.OrtSession.SessionOptions; import java.io.IOException; import java.util.EnumSet; -import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Logger; /** * The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate * specific models. + * + *

There can be at most one OrtEnvironment object created in a JVM lifetime. This class + * implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but + * the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered + * on construction. */ -public class OrtEnvironment implements AutoCloseable { +public final class OrtEnvironment implements AutoCloseable { private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName()); @@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable { private static volatile OrtEnvironment INSTANCE; - private static final AtomicInteger refCount = new AtomicInteger(); - private static volatile OrtLoggingLevel curLogLevel; private static volatile String curLoggingName; @@ -47,8 +49,6 @@ public static synchronized OrtEnvironment getEnvironment() { // If there's no instance, create one. return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME); } else { - // else return the current one. - refCount.incrementAndGet(); return INSTANCE; } } @@ -106,7 +106,6 @@ public static synchronized OrtEnvironment getEnvironment( "Tried to change OrtEnvironment's logging level or name while a reference exists."); } } - refCount.incrementAndGet(); return INSTANCE; } @@ -131,7 +130,6 @@ public static synchronized OrtEnvironment getEnvironment( } catch (OrtException e) { throw new IllegalStateException("Failed to create OrtEnvironment", e); } - refCount.incrementAndGet(); return INSTANCE; } else { // As the thread pool state is unknown, and that's probably not what the user wanted. @@ -144,8 +142,6 @@ public static synchronized OrtEnvironment getEnvironment( final OrtAllocator defaultAllocator; - private volatile boolean closed = false; - /** * Create an OrtEnvironment using a default name. * @@ -165,6 +161,8 @@ private OrtEnvironment() throws OrtException { private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException { nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name); defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true); + Runtime.getRuntime() + .addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle))); } /** @@ -181,6 +179,8 @@ private OrtEnvironment(OrtLoggingLevel loggingLevel, String name, ThreadingOptio createHandle( OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle); defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true); + Runtime.getRuntime() + .addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle))); } /** @@ -219,11 +219,7 @@ public OrtSession createSession(String modelPath, SessionOptions options) throws */ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options) throws OrtException { - if (!closed) { - return new OrtSession(this, modelPath, allocator, options); - } else { - throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment."); - } + return new OrtSession(this, modelPath, allocator, options); } /** @@ -262,11 +258,7 @@ public OrtSession createSession(byte[] modelArray) throws OrtException { */ OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options) throws OrtException { - if (!closed) { - return new OrtSession(this, modelArray, allocator, options); - } else { - throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment."); - } + return new OrtSession(this, modelArray, allocator, options); } /** @@ -279,41 +271,11 @@ public void setTelemetry(boolean sendTelemetry) throws OrtException { setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry); } - /** - * Is this environment closed? - * - * @return True if the environment is closed. - */ - public boolean isClosed() { - return closed; - } - @Override public String toString() { return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")"; } - /** - * Closes the OrtEnvironment. If this is the last reference to the environment then it closes the - * native handle. - * - * @throws OrtException If the close failed. - */ - @Override - public synchronized void close() throws OrtException { - synchronized (refCount) { - int curCount = refCount.get(); - if (curCount != 0) { - refCount.decrementAndGet(); - } - if (curCount == 1) { - close(OnnxRuntime.ortApiHandle, nativeHandle); - closed = true; - INSTANCE = null; - } - } - } - /** * Gets the providers available in this environment. * @@ -377,11 +339,22 @@ private static native long createHandle( private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry) throws OrtException; + /** Close is a no-op on OrtEnvironment since ORT 1.11. */ + @Override + public void close() {} + /** * Controls the global thread pools in the environment. Only used if the session is constructed * using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set. */ public static final class ThreadingOptions implements AutoCloseable { + static { + try { + OnnxRuntime.init(); + } catch (IOException e) { + throw new RuntimeException("Failed to load onnx-runtime library", e); + } + } private final long nativeHandle; @@ -486,4 +459,24 @@ private native void setGlobalDenormalAsZero(long apiHandle, long nativeHandle) private native void closeThreadingOptions(long apiHandle, long nativeHandle); } + + private static final class OrtEnvCloser implements Runnable { + + private final long apiHandle; + private final long nativeHandle; + + OrtEnvCloser(long apiHandle, long nativeHandle) { + this.apiHandle = apiHandle; + this.nativeHandle = nativeHandle; + } + + @Override + public void run() { + try { + OrtEnvironment.close(apiHandle, nativeHandle); + } catch (OrtException e) { + System.err.println("Error closing OrtEnvironment, " + e); + } + } + } } diff --git a/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java b/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java new file mode 100644 index 0000000000000..514976723840b --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import static ai.onnxruntime.TestHelpers.getResourcePath; +import static ai.onnxruntime.TestHelpers.loadTensorFromFile; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; + +/** This test is in a separate class to ensure it is run in a clean JVM. */ +public class EnvironmentThreadPoolTest { + + @Test + public void environmentThreadPoolTest() throws OrtException { + Path squeezeNet = getResourcePath("/squeezenet.onnx"); + String modelPath = squeezeNet.toString(); + float[] inputData = loadTensorFromFile(getResourcePath("/bench.in")); + float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out")); + Map container = new HashMap<>(); + + OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions(); + threadOpts.setGlobalInterOpNumThreads(2); + threadOpts.setGlobalIntraOpNumThreads(2); + threadOpts.setGlobalDenormalAsZero(); + threadOpts.setGlobalSpinControl(true); + OrtEnvironment env = + OrtEnvironment.getEnvironment( + OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts); + try (OrtSession.SessionOptions options = new OrtSession.SessionOptions(); + OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) { + disableThreadOptions.disablePerSessionThreads(); + + // Check that the regular session executes + try (OrtSession session = env.createSession(modelPath, options)) { + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; + Object tensorData = OrtUtil.reshape(inputData, inputShape); + OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData); + container.put(inputMeta.getName(), tensor); + try (OrtSession.Result result = session.run(container)) { + OnnxValue resultTensor = result.get(0); + float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); + assertEquals(expectedOutput.length, resultArray.length); + assertArrayEquals(expectedOutput, resultArray, 1e-6f); + } + container.clear(); + tensor.close(); + } + + // Check that the session using the env thread pool executes + try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) { + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; + Object tensorData = OrtUtil.reshape(inputData, inputShape); + OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData); + container.put(inputMeta.getName(), tensor); + try (OrtSession.Result result = session.run(container)) { + OnnxValue resultTensor = result.get(0); + float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); + assertEquals(expectedOutput.length, resultArray.length); + assertArrayEquals(expectedOutput, resultArray, 1e-6f); + } + container.clear(); + tensor.close(); + } + } + + try { + OrtEnvironment newEnv = + OrtEnvironment.getEnvironment( + OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts); + // fail as we can't recreate environments with different threading options + fail("Should have thrown IllegalStateException"); + } catch (IllegalStateException e) { + // pass + } + + threadOpts.close(); + } +} diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index b73583cab7441..6a6bfa0a7c6a3 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,30 +1,22 @@ /* - * Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import ai.onnxruntime.OnnxMl.TensorProto; -import ai.onnxruntime.OnnxMl.TensorProto.DataType; import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode; import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; -import java.io.BufferedInputStream; -import java.io.BufferedReader; import java.io.File; -import java.io.FileInputStream; -import java.io.FileReader; import java.io.IOException; -import java.io.InputStream; -import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; @@ -44,7 +36,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; -import java.util.logging.Level; import java.util.logging.Logger; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -57,76 +48,65 @@ /** Tests for the onnx-runtime Java interface. */ public class InferenceTest { private static final Logger logger = Logger.getLogger(InferenceTest.class.getName()); - private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]"); private static final String propertiesFile = "Properties.txt"; private static final Pattern inputPBPattern = Pattern.compile("input_*.pb"); private static final Pattern outputPBPattern = Pattern.compile("output_*.pb"); - private static Path getResourcePath(String path) { - return new File(InferenceTest.class.getResource(path).getFile()).toPath(); - } + private static final OrtEnvironment env = OrtEnvironment.getEnvironment(); @Test - public void repeatedCloseTest() throws OrtException { - Logger.getLogger(OrtEnvironment.class.getName()).setLevel(Level.SEVERE); - OrtEnvironment env = OrtEnvironment.getEnvironment("repeatedCloseTest"); - try (OrtEnvironment otherEnv = OrtEnvironment.getEnvironment()) { - assertFalse(otherEnv.isClosed()); - } - assertFalse(env.isClosed()); - env.close(); - assertTrue(env.isClosed()); + public void environmentTest() { + // Checks that the environment instance is the same. + OrtEnvironment otherEnv = OrtEnvironment.getEnvironment(); + assertSame(env, otherEnv); + otherEnv = OrtEnvironment.getEnvironment("test-name"); + assertSame(env, otherEnv); } @Test public void createSessionFromPath() throws OrtException { - String modelPath = getResourcePath("/squeezenet.onnx").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromPath"); - OrtSession.SessionOptions options = new SessionOptions()) { - try (OrtSession session = env.createSession(modelPath, options)) { - assertNotNull(session); - assertEquals(1, session.getNumInputs()); // 1 input node - Map inputInfoList = session.getInputInfo(); - assertNotNull(inputInfoList); - assertEquals(1, inputInfoList.size()); - NodeInfo input = inputInfoList.get("data_0"); - assertEquals("data_0", input.getName()); // input node name - assertTrue(input.getInfo() instanceof TensorInfo); - TensorInfo inputInfo = (TensorInfo) input.getInfo(); - assertEquals(OnnxJavaType.FLOAT, inputInfo.type); - int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; - assertEquals(expectedInputDimensions.length, inputInfo.shape.length); - for (int i = 0; i < expectedInputDimensions.length; i++) { - assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); - } + String modelPath = TestHelpers.getResourcePath("/squeezenet.onnx").toString(); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelPath, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } - assertEquals(1, session.getNumOutputs()); // 1 output node - Map outputInfoList = session.getOutputInfo(); - assertNotNull(outputInfoList); - assertEquals(1, outputInfoList.size()); - NodeInfo output = outputInfoList.get("softmaxout_1"); - assertEquals("softmaxout_1", output.getName()); // output node name - assertTrue(output.getInfo() instanceof TensorInfo); - TensorInfo outputInfo = (TensorInfo) output.getInfo(); - assertEquals(OnnxJavaType.FLOAT, outputInfo.type); - int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1}; - assertEquals(expectedOutputDimensions.length, outputInfo.shape.length); - for (int i = 0; i < expectedOutputDimensions.length; i++) { - assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]); - } + assertEquals(1, session.getNumOutputs()); // 1 output node + Map outputInfoList = session.getOutputInfo(); + assertNotNull(outputInfoList); + assertEquals(1, outputInfoList.size()); + NodeInfo output = outputInfoList.get("softmaxout_1"); + assertEquals("softmaxout_1", output.getName()); // output node name + assertTrue(output.getInfo() instanceof TensorInfo); + TensorInfo outputInfo = (TensorInfo) output.getInfo(); + assertEquals(OnnxJavaType.FLOAT, outputInfo.type); + int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1}; + assertEquals(expectedOutputDimensions.length, outputInfo.shape.length); + for (int i = 0; i < expectedOutputDimensions.length; i++) { + assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]); } } } @Test public void morePartialInputsTest() throws OrtException { - String modelPath = getResourcePath("/partial-inputs-test-2.onnx").toString(); - try (OrtEnvironment env = - OrtEnvironment.getEnvironment( - OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); - OrtSession.SessionOptions options = new SessionOptions(); + String modelPath = TestHelpers.getResourcePath("/partial-inputs-test-2.onnx").toString(); + try (OrtSession.SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { assertNotNull(session); assertEquals(3, session.getNumInputs()); @@ -209,11 +189,8 @@ public void morePartialInputsTest() throws OrtException { @Test public void partialInputsTest() throws OrtException { - String modelPath = getResourcePath("/partial-inputs-test.onnx").toString(); - try (OrtEnvironment env = - OrtEnvironment.getEnvironment( - OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "partialInputs"); - OrtSession.SessionOptions options = new SessionOptions(); + String modelPath = TestHelpers.getResourcePath("/partial-inputs-test.onnx").toString(); + try (OrtSession.SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { assertNotNull(session); assertEquals(3, session.getNumInputs()); @@ -347,122 +324,50 @@ public void partialInputsTest() throws OrtException { @Test public void createSessionFromByteArray() throws IOException, OrtException { - Path modelPath = getResourcePath("/squeezenet.onnx"); + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); byte[] modelBytes = Files.readAllBytes(modelPath); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("createSessionFromByteArray"); - OrtSession.SessionOptions options = new SessionOptions()) { - try (OrtSession session = env.createSession(modelBytes, options)) { - assertNotNull(session); - assertEquals(1, session.getNumInputs()); // 1 input node - Map inputInfoList = session.getInputInfo(); - assertNotNull(inputInfoList); - assertEquals(1, inputInfoList.size()); - NodeInfo input = inputInfoList.get("data_0"); - assertEquals("data_0", input.getName()); // input node name - assertTrue(input.getInfo() instanceof TensorInfo); - TensorInfo inputInfo = (TensorInfo) input.getInfo(); - assertEquals(OnnxJavaType.FLOAT, inputInfo.type); - int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; - assertEquals(expectedInputDimensions.length, inputInfo.shape.length); - for (int i = 0; i < expectedInputDimensions.length; i++) { - assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); - } - - assertEquals(1, session.getNumOutputs()); // 1 output node - Map outputInfoList = session.getOutputInfo(); - assertNotNull(outputInfoList); - assertEquals(1, outputInfoList.size()); - NodeInfo output = outputInfoList.get("softmaxout_1"); - assertEquals("softmaxout_1", output.getName()); // output node name - assertTrue(output.getInfo() instanceof TensorInfo); - TensorInfo outputInfo = (TensorInfo) output.getInfo(); - assertEquals(OnnxJavaType.FLOAT, outputInfo.type); - int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1}; - assertEquals(expectedOutputDimensions.length, outputInfo.shape.length); - for (int i = 0; i < expectedOutputDimensions.length; i++) { - assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]); - } - - // Check the metadata can be extracted - OnnxModelMetadata metadata = session.getMetadata(); - assertEquals("onnx-caffe2", metadata.getProducerName()); - assertEquals("squeezenet_old", metadata.getGraphName()); - assertEquals("", metadata.getDomain()); - assertEquals("", metadata.getDescription()); - assertEquals(0x7FFFFFFFFFFFFFFFL, metadata.getVersion()); - assertTrue(metadata.getCustomMetadata().isEmpty()); - } - } - } - - @Test - public void environmentThreadPoolTest() throws OrtException { - Path squeezeNet = getResourcePath("/squeezenet.onnx"); - String modelPath = squeezeNet.toString(); - float[] inputData = loadTensorFromFile(getResourcePath("/bench.in")); - float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out")); - Map container = new HashMap<>(); - - OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions(); - threadOpts.setGlobalInterOpNumThreads(2); - threadOpts.setGlobalIntraOpNumThreads(2); - threadOpts.setGlobalDenormalAsZero(); - threadOpts.setGlobalSpinControl(true); - try (OrtEnvironment env = - OrtEnvironment.getEnvironment( - OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts); - OrtSession.SessionOptions options = new SessionOptions(); - OrtSession.SessionOptions disableThreadOptions = new SessionOptions()) { - disableThreadOptions.disablePerSessionThreads(); - - // Check that the regular session executes - try (OrtSession session = env.createSession(modelPath, options)) { - NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; - Object tensorData = OrtUtil.reshape(inputData, inputShape); - OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData); - container.put(inputMeta.getName(), tensor); - try (OrtSession.Result result = session.run(container)) { - OnnxValue resultTensor = result.get(0); - float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); - assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); - } - container.clear(); - tensor.close(); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBytes, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); } - // Check that the session using the env thread pool executes - try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) { - NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; - Object tensorData = OrtUtil.reshape(inputData, inputShape); - OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData); - container.put(inputMeta.getName(), tensor); - try (OrtSession.Result result = session.run(container)) { - OnnxValue resultTensor = result.get(0); - float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); - assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); - } - container.clear(); - tensor.close(); + assertEquals(1, session.getNumOutputs()); // 1 output node + Map outputInfoList = session.getOutputInfo(); + assertNotNull(outputInfoList); + assertEquals(1, outputInfoList.size()); + NodeInfo output = outputInfoList.get("softmaxout_1"); + assertEquals("softmaxout_1", output.getName()); // output node name + assertTrue(output.getInfo() instanceof TensorInfo); + TensorInfo outputInfo = (TensorInfo) output.getInfo(); + assertEquals(OnnxJavaType.FLOAT, outputInfo.type); + int[] expectedOutputDimensions = new int[] {1, 1000, 1, 1}; + assertEquals(expectedOutputDimensions.length, outputInfo.shape.length); + for (int i = 0; i < expectedOutputDimensions.length; i++) { + assertEquals(expectedOutputDimensions[i], outputInfo.shape[i]); } - } - try (OrtEnvironment env = OrtEnvironment.getEnvironment("test")) { - try { - OrtEnvironment newEnv = - OrtEnvironment.getEnvironment( - OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts); - // fail as we can't recreate environments with different threading options - fail("Should have thrown IllegalStateException"); - } catch (IllegalStateException e) { - // pass - } + // Check the metadata can be extracted + OnnxModelMetadata metadata = session.getMetadata(); + assertEquals("onnx-caffe2", metadata.getProducerName()); + assertEquals("squeezenet_old", metadata.getGraphName()); + assertEquals("", metadata.getDomain()); + assertEquals("", metadata.getDescription()); + assertEquals(0x7FFFFFFFFFFFFFFFL, metadata.getVersion()); + assertTrue(metadata.getCustomMetadata().isEmpty()); } - - threadOpts.close(); } @Test @@ -475,11 +380,10 @@ public void inferenceTest() throws OrtException { private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionMode exectionMode) throws OrtException { - String modelPath = getResourcePath("/squeezenet.onnx").toString(); + String modelPath = TestHelpers.getResourcePath("/squeezenet.onnx").toString(); // Set the graph optimization level for this session. - try (OrtEnvironment env = OrtEnvironment.getEnvironment("canRunInferenceOnAModel"); - SessionOptions options = new SessionOptions()) { + try (SessionOptions options = new SessionOptions()) { options.setOptimizationLevel(graphOptimizationLevel); options.setExecutionMode(exectionMode); @@ -488,7 +392,8 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM Map container = new HashMap<>(); NodeInfo inputMeta = inputMetaMap.values().iterator().next(); - float[] inputData = loadTensorFromFile(getResourcePath("/bench.in")); + float[] inputData = + TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.in")); // this is the data for only one input tensor for this model Object tensorData = OrtUtil.reshape(inputData, ((TensorInfo) inputMeta.getInfo()).getShape()); @@ -499,7 +404,8 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM try (OrtSession.Result results = session.run(container)) { assertEquals(1, results.size()); - float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out")); + float[] expectedOutput = + TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.expected_out")); // validate the results // Only iterates once for (Map.Entry r : results) { @@ -529,8 +435,7 @@ private void canRunInferenceOnAModel(OptLevel graphOptimizationLevel, ExecutionM @Test public void throwWrongInputName() throws OrtException { SqueezeNetTuple tuple = openSessionSqueezeNet(); - try (OrtEnvironment env = tuple.env; - OrtSession session = tuple.session) { + try (OrtSession session = tuple.session) { float[] inputData = tuple.inputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); Map container = new HashMap<>(); @@ -556,8 +461,7 @@ public void throwWrongInputName() throws OrtException { @Test public void throwWrongInputType() throws OrtException { SqueezeNetTuple tuple = openSessionSqueezeNet(); - try (OrtEnvironment env = tuple.env; - OrtSession session = tuple.session) { + try (OrtSession session = tuple.session) { float[] inputData = tuple.inputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); @@ -584,8 +488,7 @@ public void throwWrongInputType() throws OrtException { @Test public void throwExtraInputs() throws OrtException { SqueezeNetTuple tuple = openSessionSqueezeNet(); - try (OrtEnvironment env = tuple.env; - OrtSession session = tuple.session) { + try (OrtSession session = tuple.session) { float[] inputData = tuple.inputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); @@ -615,8 +518,7 @@ public void testMultiThreads() throws OrtException, InterruptedException { int numThreads = 10; int loop = 10; SqueezeNetTuple tuple = openSessionSqueezeNet(); - try (OrtEnvironment env = tuple.env; - OrtSession session = tuple.session) { + try (OrtSession session = tuple.session) { float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; @@ -664,25 +566,23 @@ public void testProviders() { @Test public void testSymbolicDimensionAssignment() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/capi_symbolic_dims.onnx").toString(); + String modelPath = TestHelpers.getResourcePath("/capi_symbolic_dims.onnx").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testSymbolicDimensionAssignment")) { - // Check the dimension is symbolic - try (SessionOptions options = new SessionOptions()) { - try (OrtSession session = env.createSession(modelPath, options)) { - Map infoMap = session.getInputInfo(); - TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); - assertArrayEquals(new long[] {-1, 2}, aInfo.shape); - } + // Check the dimension is symbolic + try (SessionOptions options = new SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options)) { + Map infoMap = session.getInputInfo(); + TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); + assertArrayEquals(new long[] {-1, 2}, aInfo.shape); } - // Check that when the options are assigned it overrides the symbolic dimension - try (SessionOptions options = new SessionOptions()) { - options.setSymbolicDimensionValue("n", 5); - try (OrtSession session = env.createSession(modelPath, options)) { - Map infoMap = session.getInputInfo(); - TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); - assertArrayEquals(new long[] {5, 2}, aInfo.shape); - } + } + // Check that when the options are assigned it overrides the symbolic dimension + try (SessionOptions options = new SessionOptions()) { + options.setSymbolicDimensionValue("n", 5); + try (OrtSession session = env.createSession(modelPath, options)) { + Map infoMap = session.getInputInfo(); + TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); + assertArrayEquals(new long[] {5, 2}, aInfo.shape); } } } @@ -723,8 +623,7 @@ private void runProvider(OrtProvider provider) throws OrtException { assertTrue(providers.contains(OrtProvider.CPU)); assertTrue(providers.contains(provider)); SqueezeNetTuple tuple = openSessionSqueezeNet(EnumSet.of(provider)); - try (OrtEnvironment env = tuple.env; - OrtSession session = tuple.session) { + try (OrtSession session = tuple.session) { float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); @@ -853,8 +752,7 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti + modelNamesList); } - try (OrtEnvironment env = OrtEnvironment.getEnvironment(); - OrtSession session = env.createSession(onnxModelFileName)) { + try (OrtSession session = env.createSession(onnxModelFileName)) { String testDataDirNamePattern; if (opset.equals("opset9") && modelName.equals("LSTM_Seq_lens_unpacked")) { testDataDirNamePattern = "seq_lens"; // discrepency in data directory @@ -869,12 +767,12 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti Map outputContainer = new HashMap<>(); for (File f : testDataDir.listFiles((dir, name) -> inputPBPattern.matcher(name).matches())) { - StringTensorPair o = loadTensorFromFilePb(env, f, inMeta); + TestHelpers.StringTensorPair o = TestHelpers.loadTensorFromFilePb(env, f, inMeta); inputContainer.put(o.string, o.tensor); } for (File f : testDataDir.listFiles((dir, name) -> outputPBPattern.matcher(name).matches())) { - StringTensorPair o = loadTensorFromFilePb(env, f, outMeta); + TestHelpers.StringTensorPair o = TestHelpers.loadTensorFromFilePb(env, f, outMeta); outputContainer.put(o.string, o.tensor); } @@ -915,10 +813,9 @@ public void testPreTrainedModel(String opset, String modelName) throws IOExcepti @Test public void testModelInputFLOAT() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_FLOAT.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_FLOAT.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); long[] shape = new long[] {1, 5}; @@ -942,10 +839,9 @@ public void testModelInputFLOAT() throws OrtException { @Test public void testModelInputBuffer() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_FLOAT.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_FLOAT.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputFLOAT"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); long[] shape = new long[] {1, 5}; @@ -1003,10 +899,9 @@ public void testModelInputBuffer() throws OrtException { @Test public void testRunOptions() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_BOOL.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testRunOptions"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options); OrtSession.RunOptions runOptions = new OrtSession.RunOptions()) { runOptions.setRunTag("monkeys"); @@ -1035,73 +930,71 @@ public void testRunOptions() throws OrtException { @Test public void testExtraSessionOptions() throws OrtException, IOException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_BOOL.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString(); File tmpPath = File.createTempFile("onnx-runtime-profiling", "file"); tmpPath.deleteOnExit(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testExtraSessionOptions")) { - try (SessionOptions options = new SessionOptions()) { - options.setCPUArenaAllocator(true); - options.setMemoryPatternOptimization(true); - options.enableProfiling(tmpPath.getAbsolutePath()); - options.setLoggerId("monkeys"); - options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); - options.setSessionLogVerbosityLevel(5); - Map configEntries = options.getConfigEntries(); - assertTrue(configEntries.isEmpty()); - options.addConfigEntry("key", "value"); - assertEquals("value", configEntries.get("key")); - try { - options.addConfigEntry("", "invalid key"); - fail("Add config entry with empty key should have failed"); - } catch (OrtException e) { - assertTrue(e.getMessage().contains("Config key is empty")); - assertEquals(OrtException.OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + try (SessionOptions options = new SessionOptions()) { + options.setCPUArenaAllocator(true); + options.setMemoryPatternOptimization(true); + options.enableProfiling(tmpPath.getAbsolutePath()); + options.setLoggerId("monkeys"); + options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); + options.setSessionLogVerbosityLevel(5); + Map configEntries = options.getConfigEntries(); + assertTrue(configEntries.isEmpty()); + options.addConfigEntry("key", "value"); + assertEquals("value", configEntries.get("key")); + try { + options.addConfigEntry("", "invalid key"); + fail("Add config entry with empty key should have failed"); + } catch (OrtException e) { + assertTrue(e.getMessage().contains("Config key is empty")); + assertEquals(OrtException.OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + try (OrtSession session = env.createSession(modelPath, options)) { + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + boolean[] flatInput = new boolean[] {true, false, true, false, true}; + Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); } - try (OrtSession session = env.createSession(modelPath, options)) { - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - boolean[] flatInput = new boolean[] {true, false, true, false, true}; - Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); - OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); - container.put(inputName, ov); - try (OrtSession.Result res = session.run(container)) { - boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); - assertArrayEquals(flatInput, resultArray); - } - // Check that the profiling start time doesn't throw - long profilingStartTime = session.getProfilingStartTimeInNs(); + // Check that the profiling start time doesn't throw + long profilingStartTime = session.getProfilingStartTimeInNs(); - // Check the profiling output doesn't throw - String profilingOutput = session.endProfiling(); - File profilingOutputFile = new File(profilingOutput); - profilingOutputFile.deleteOnExit(); - try (OrtSession.Result res = session.run(container)) { - boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); - assertArrayEquals(flatInput, resultArray); - } - OnnxValue.close(container); + // Check the profiling output doesn't throw + String profilingOutput = session.endProfiling(); + File profilingOutputFile = new File(profilingOutput); + profilingOutputFile.deleteOnExit(); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); } + OnnxValue.close(container); } - try (SessionOptions options = new SessionOptions()) { - options.setCPUArenaAllocator(false); - options.setMemoryPatternOptimization(false); - options.enableProfiling(tmpPath.getAbsolutePath()); - options.disableProfiling(); - options.setSessionLogVerbosityLevel(0); - try (OrtSession session = env.createSession(modelPath, options)) { - String inputName = session.getInputNames().iterator().next(); - Map container = new HashMap<>(); - boolean[] flatInput = new boolean[] {true, false, true, false, true}; - Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); - OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); - container.put(inputName, ov); - try (OrtSession.Result res = session.run(container)) { - boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); - assertArrayEquals(flatInput, resultArray); - } - OnnxValue.close(container); + } + try (SessionOptions options = new SessionOptions()) { + options.setCPUArenaAllocator(false); + options.setMemoryPatternOptimization(false); + options.enableProfiling(tmpPath.getAbsolutePath()); + options.disableProfiling(); + options.setSessionLogVerbosityLevel(0); + try (OrtSession session = env.createSession(modelPath, options)) { + String inputName = session.getInputNames().iterator().next(); + Map container = new HashMap<>(); + boolean[] flatInput = new boolean[] {true, false, true, false, true}; + Object tensorIn = OrtUtil.reshape(flatInput, new long[] {1, 5}); + OnnxTensor ov = OnnxTensor.createTensor(env, tensorIn); + container.put(inputName, ov); + try (OrtSession.Result res = session.run(container)) { + boolean[] resultArray = TestHelpers.flattenBoolean(res.get(0).getValue()); + assertArrayEquals(flatInput, resultArray); } + OnnxValue.close(container); } } } @@ -1115,19 +1008,18 @@ public void testLoadCustomLibrary() throws OrtException { if (osName.contains("windows")) { // In windows we start in the wrong working directory relative to the custom_op_library.dll // So we look it up as a classpath resource and resolve it to a real path - customLibraryName = getResourcePath("/custom_op_library.dll").toString(); + customLibraryName = TestHelpers.getResourcePath("/custom_op_library.dll").toString(); } else if (osName.contains("mac")) { - customLibraryName = getResourcePath("/libcustom_op_library.dylib").toString(); + customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.dylib").toString(); } else if (osName.contains("linux")) { - customLibraryName = getResourcePath("/libcustom_op_library.so").toString(); + customLibraryName = TestHelpers.getResourcePath("/libcustom_op_library.so").toString(); } else { fail("Unknown os/platform '" + osName + "'"); } String customOpLibraryTestModel = - getResourcePath("/custom_op_library/custom_op_test.onnx").toString(); + TestHelpers.getResourcePath("/custom_op_library/custom_op_test.onnx").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testLoadCustomLibrary"); - SessionOptions options = new SessionOptions()) { + try (SessionOptions options = new SessionOptions()) { options.registerCustomOpLibrary(customLibraryName); if (OnnxRuntime.extractCUDA()) { options.addCUDA(); @@ -1170,41 +1062,39 @@ public void testLoadCustomLibrary() throws OrtException { @Test public void testModelMetadata() throws OrtException { - String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString(); + String modelPath = + TestHelpers.getResourcePath("/model_with_valid_ort_config_json.onnx").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) { - try (OrtSession session = env.createSession(modelPath)) { - OnnxModelMetadata modelMetadata = session.getMetadata(); + try (OrtSession session = env.createSession(modelPath)) { + OnnxModelMetadata modelMetadata = session.getMetadata(); - Assertions.assertEquals(1, modelMetadata.getVersion()); + Assertions.assertEquals(1, modelMetadata.getVersion()); - Assertions.assertEquals("Hari", modelMetadata.getProducerName()); + Assertions.assertEquals("Hari", modelMetadata.getProducerName()); - Assertions.assertEquals("matmul test", modelMetadata.getGraphName()); + Assertions.assertEquals("matmul test", modelMetadata.getGraphName()); - Assertions.assertEquals("", modelMetadata.getDomain()); + Assertions.assertEquals("", modelMetadata.getDomain()); - Assertions.assertEquals( - "This is a test model with a valid ORT config Json", modelMetadata.getDescription()); + Assertions.assertEquals( + "This is a test model with a valid ORT config Json", modelMetadata.getDescription()); - Assertions.assertEquals("graph description", modelMetadata.getGraphDescription()); + Assertions.assertEquals("graph description", modelMetadata.getGraphDescription()); - Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size()); - Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key")); - Assertions.assertEquals( - "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}", - modelMetadata.getCustomMetadata().get("ort_config")); - } + Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size()); + Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key")); + Assertions.assertEquals( + "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}", + modelMetadata.getCustomMetadata().get("ort_config")); } } @Test public void testModelInputBOOL() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_BOOL.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_BOOL.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputBOOL"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1223,10 +1113,9 @@ public void testModelInputBOOL() throws OrtException { @Test public void testModelInputINT32() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_INT32.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_INT32.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT32"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1245,10 +1134,9 @@ public void testModelInputINT32() throws OrtException { @Test public void testModelInputDOUBLE() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_DOUBLE.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_DOUBLE.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputDOUBLE"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1267,10 +1155,9 @@ public void testModelInputDOUBLE() throws OrtException { @Test public void testModelInputINT8() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_INT8.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_INT8.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT8"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1288,10 +1175,9 @@ public void testModelInputINT8() throws OrtException { @Test public void testModelInputUINT8() throws OrtException { - String modelPath = getResourcePath("/test_types_UINT8.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_UINT8.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputUINT8"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1311,10 +1197,9 @@ public void testModelInputUINT8() throws OrtException { @Test public void testModelInputINT16() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_INT16.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_INT16.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT16"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1333,10 +1218,9 @@ public void testModelInputINT16() throws OrtException { @Test public void testModelInputINT64() throws OrtException { // model takes 1x5 input of fixed type, echoes back - String modelPath = getResourcePath("/test_types_INT64.pb").toString(); + String modelPath = TestHelpers.getResourcePath("/test_types_INT64.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelInputINT64"); - SessionOptions options = new SessionOptions(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { String inputName = session.getInputNames().iterator().next(); Map container = new HashMap<>(); @@ -1360,9 +1244,8 @@ public void testModelSequenceOfMapIntFloat() throws OrtException { // "probabilities" is a sequence> // https://github.com/onnx/sklearn-onnx/blob/master/docs/examples/plot_pipeline_lightgbm.py - String modelPath = getResourcePath("/test_sequence_map_int_float.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelSequenceOfMapIntFloat"); - SessionOptions options = new SessionOptions(); + String modelPath = TestHelpers.getResourcePath("/test_sequence_map_int_float.pb").toString(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { Map outputInfos = session.getOutputInfo(); @@ -1427,9 +1310,8 @@ public void testModelSequenceOfMapStringFloat() throws OrtException { // "label" is a tensor, // "probabilities" is a sequence> // https://github.com/onnx/sklearn-onnx/blob/master/docs/examples/plot_pipeline_lightgbm.py - String modelPath = getResourcePath("/test_sequence_map_string_float.pb").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelSequenceOfMapStringFloat"); - SessionOptions options = new SessionOptions(); + String modelPath = TestHelpers.getResourcePath("/test_sequence_map_string_float.pb").toString(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { Map outputInfos = session.getOutputInfo(); @@ -1491,14 +1373,13 @@ public void testModelSequenceOfMapStringFloat() throws OrtException { @Test public void testModelSerialization() throws OrtException, IOException { String cwd = System.getProperty("user.dir"); - Path squeezeNet = getResourcePath("/squeezenet.onnx"); + Path squeezeNet = TestHelpers.getResourcePath("/squeezenet.onnx"); String modelPath = squeezeNet.toString(); File tmpFile = File.createTempFile("optimized-squeezenet", ".onnx"); String modelOutputPath = tmpFile.getAbsolutePath(); Assertions.assertEquals(0, tmpFile.length()); - try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { - // Set the optimized model file path to assert that no exception are thrown. - SessionOptions options = new SessionOptions(); + // Set the optimized model file path to assert that no exception are thrown. + try (SessionOptions options = new SessionOptions()) { options.setOptimizedModelFilePath(modelOutputPath); options.setOptimizationLevel(OptLevel.BASIC_OPT); try (OrtSession session = env.createSession(modelPath, options)) { @@ -1513,9 +1394,8 @@ public void testModelSerialization() throws OrtException, IOException { @Test public void testStringIdentity() throws OrtException { - String modelPath = getResourcePath("/identity_string.onnx").toString(); - try (OrtEnvironment env = OrtEnvironment.getEnvironment("testStringIdentity"); - SessionOptions options = new SessionOptions(); + String modelPath = TestHelpers.getResourcePath("/identity_string.onnx").toString(); + try (SessionOptions options = new SessionOptions(); OrtSession session = env.createSession(modelPath, options)) { Map outputInfos = session.getOutputInfo(); @@ -1576,14 +1456,11 @@ public void testStringIdentity() throws OrtException { /** Carrier tuple for the squeeze net model. */ private static class SqueezeNetTuple { - public final OrtEnvironment env; public final OrtSession session; public final float[] inputData; public final float[] outputData; - public SqueezeNetTuple( - OrtEnvironment env, OrtSession session, float[] inputData, float[] outputData) { - this.env = env; + public SqueezeNetTuple(OrtSession session, float[] inputData, float[] outputData) { this.session = session; this.inputData = inputData; this.outputData = outputData; @@ -1603,9 +1480,8 @@ private static SqueezeNetTuple openSessionSqueezeNet() throws OrtException { */ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet providers) throws OrtException { - Path squeezeNet = getResourcePath("/squeezenet.onnx"); + Path squeezeNet = TestHelpers.getResourcePath("/squeezenet.onnx"); String modelPath = squeezeNet.toString(); - OrtEnvironment env = OrtEnvironment.getEnvironment(); SessionOptions options = new SessionOptions(); for (OrtProvider p : providers) { switch (p) { @@ -1650,179 +1526,9 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid } } OrtSession session = env.createSession(modelPath, options); - float[] inputData = loadTensorFromFile(getResourcePath("/bench.in")); - float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out")); - return new SqueezeNetTuple(env, session, inputData, expectedOutput); - } - - private static float[] loadTensorFromFile(Path filename) { - return loadTensorFromFile(filename, true); - } - - private static float[] loadTensorFromFile(Path filename, boolean skipHeader) { - // read data from file - try (BufferedReader reader = new BufferedReader(new FileReader(filename.toFile()))) { - if (skipHeader) { - reader.readLine(); // skip the input name - } - String[] dataStr = LOAD_PATTERN.split(reader.readLine()); - List tensorData = new ArrayList<>(); - for (int i = 0; i < dataStr.length; i++) { - if (!dataStr[i].isEmpty()) { - tensorData.add(Float.parseFloat(dataStr[i])); - } - } - return TestHelpers.toPrimitiveFloat(tensorData); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - private static class TypeWidth { - public final OnnxJavaType type; - public final int width; - - public TypeWidth(OnnxJavaType type, int width) { - this.type = type; - this.width = width; - } - } - - private static TypeWidth getTypeAndWidth(TensorProto.DataType elemType) { - OnnxJavaType type; - int width; - switch (elemType) { - case FLOAT: - type = OnnxJavaType.FLOAT; - width = 4; - break; - case UINT8: - case INT8: - type = OnnxJavaType.INT8; - width = 1; - break; - case UINT16: - case INT16: - type = OnnxJavaType.INT16; - width = 2; - break; - case INT32: - case UINT32: - type = OnnxJavaType.INT32; - width = 4; - break; - case INT64: - case UINT64: - type = OnnxJavaType.INT64; - width = 8; - break; - case STRING: - type = OnnxJavaType.STRING; - width = 1; - break; - case BOOL: - type = OnnxJavaType.BOOL; - width = 1; - break; - case FLOAT16: - type = OnnxJavaType.FLOAT; - width = 2; - break; - case DOUBLE: - type = OnnxJavaType.DOUBLE; - width = 8; - break; - default: - type = null; - width = 0; - break; - } - return new TypeWidth(type, width); - } - - private static StringTensorPair loadTensorFromFilePb( - OrtEnvironment env, File filename, Map nodeMetaDict) - throws IOException, OrtException { - InputStream is = new BufferedInputStream(new FileInputStream(filename), 1024 * 1024 * 4); - OnnxMl.TensorProto tensor = OnnxMl.TensorProto.parseFrom(is); - is.close(); - - TypeWidth tw = getTypeAndWidth(DataType.forNumber(tensor.getDataType())); - int width = tw.width; - OnnxJavaType tensorElemType = tw.type; - long[] intDims = new long[tensor.getDimsCount()]; - for (int i = 0; i < tensor.getDimsCount(); i++) { - intDims[i] = tensor.getDims(i); - } - - TensorInfo nodeMeta = null; - String nodeName = ""; - if (nodeMetaDict.size() == 1) { - for (Map.Entry e : nodeMetaDict.entrySet()) { - nodeMeta = (TensorInfo) e.getValue().getInfo(); - nodeName = e.getKey(); // valid for single node input - } - } else if (nodeMetaDict.size() > 1) { - if (!tensor.getName().isEmpty()) { - nodeMeta = (TensorInfo) nodeMetaDict.get(tensor.getName()).getInfo(); - nodeName = tensor.getName(); - } else { - boolean matchfound = false; - // try to find from matching type and shape - for (Map.Entry e : nodeMetaDict.entrySet()) { - if (e.getValue().getInfo() instanceof TensorInfo) { - TensorInfo meta = (TensorInfo) e.getValue().getInfo(); - if (tensorElemType == meta.type && tensor.getDimsCount() == meta.shape.length) { - int i = 0; - for (; i < meta.shape.length; i++) { - if (meta.shape[i] != -1 && meta.shape[i] != intDims[i]) { - break; - } - } - if (i >= meta.shape.length) { - matchfound = true; - nodeMeta = meta; - nodeName = e.getKey(); - break; - } - } - } - } - if (!matchfound) { - // throw error - throw new IllegalStateException( - "No matching Tensor found in InputOutputMetadata corresponding to the serialized tensor loaded from " - + filename); - } - } - } else { - // throw error - throw new IllegalStateException( - "While reading the serialized tensor loaded from " - + filename - + ", metaDataDict has 0 elements"); - } - - Assertions.assertEquals(tensorElemType, nodeMeta.type); - Assertions.assertEquals(nodeMeta.shape.length, tensor.getDimsCount()); - for (int i = 0; i < nodeMeta.shape.length; i++) { - Assertions.assertTrue((nodeMeta.shape[i] == -1) || (nodeMeta.shape[i] == intDims[i])); - } - - ByteBuffer buffer = ByteBuffer.wrap(tensor.getRawData().toByteArray()); - - OnnxTensor onnxTensor = OnnxTensor.createTensor(env, buffer, intDims, tensorElemType); - - return new StringTensorPair(nodeName, onnxTensor); - } - - private static class StringTensorPair { - public final String string; - public final OnnxTensor tensor; - - public StringTensorPair(String string, OnnxTensor tensor) { - this.string = string; - this.tensor = tensor; - } + float[] inputData = TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.in")); + float[] expectedOutput = + TestHelpers.loadTensorFromFile(TestHelpers.getResourcePath("/bench.expected_out")); + return new SqueezeNetTuple(session, inputData, expectedOutput); } } diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java index 1b5eafd354c20..681179beff9c9 100644 --- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java +++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java @@ -12,117 +12,114 @@ public class TensorCreationTest { @Test public void testScalarCreation() throws OrtException { - try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { - String[] stringValues = new String[] {"true", "false"}; - for (String s : stringValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, s)) { - Assertions.assertEquals(s, t.getValue()); - } + OrtEnvironment env = OrtEnvironment.getEnvironment(); + String[] stringValues = new String[] {"true", "false"}; + for (String s : stringValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, s)) { + Assertions.assertEquals(s, t.getValue()); } + } - boolean[] boolValues = new boolean[] {true, false}; - for (boolean b : boolValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, b)) { - Assertions.assertEquals(b, t.getValue()); - } + boolean[] boolValues = new boolean[] {true, false}; + for (boolean b : boolValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, b)) { + Assertions.assertEquals(b, t.getValue()); } + } - int[] intValues = - new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE}; - for (int i : intValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, i)) { - Assertions.assertEquals(i, t.getValue()); - } + int[] intValues = + new int[] {-1, 0, 1, 12345678, -12345678, Integer.MAX_VALUE, Integer.MIN_VALUE}; + for (int i : intValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, i)) { + Assertions.assertEquals(i, t.getValue()); } + } - long[] longValues = - new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE}; - for (long l : longValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, l)) { - Assertions.assertEquals(l, t.getValue()); - } + long[] longValues = + new long[] {-1L, 0L, 1L, 12345678L, -12345678L, Long.MAX_VALUE, Long.MIN_VALUE}; + for (long l : longValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, l)) { + Assertions.assertEquals(l, t.getValue()); } + } - float[] floatValues = - new float[] { - -1.0f, - 0.0f, - -0.0f, - 1.0f, - 1234.5678f, - -1234.5678f, - (float) Math.PI, - (float) Math.E, - Float.MAX_VALUE, - Float.MIN_VALUE - }; - for (float f : floatValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, f)) { - Assertions.assertEquals(f, t.getValue()); - } + float[] floatValues = + new float[] { + -1.0f, + 0.0f, + -0.0f, + 1.0f, + 1234.5678f, + -1234.5678f, + (float) Math.PI, + (float) Math.E, + Float.MAX_VALUE, + Float.MIN_VALUE + }; + for (float f : floatValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, f)) { + Assertions.assertEquals(f, t.getValue()); } + } - double[] doubleValues = - new double[] { - -1.0, - 0.0, - -0.0, - 1.0, - 1234.5678, - -1234.5678, - Math.PI, - Math.E, - Double.MAX_VALUE, - Double.MIN_VALUE - }; - for (double d : doubleValues) { - try (OnnxTensor t = OnnxTensor.createTensor(env, d)) { - Assertions.assertEquals(d, t.getValue()); - } + double[] doubleValues = + new double[] { + -1.0, + 0.0, + -0.0, + 1.0, + 1234.5678, + -1234.5678, + Math.PI, + Math.E, + Double.MAX_VALUE, + Double.MIN_VALUE + }; + for (double d : doubleValues) { + try (OnnxTensor t = OnnxTensor.createTensor(env, d)) { + Assertions.assertEquals(d, t.getValue()); } } } @Test public void testStringCreation() throws OrtException { - try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { - String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"}; - try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { - Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape); - String[] output = (String[]) t.getValue(); - Assertions.assertArrayEquals(arrValues, output); - } + OrtEnvironment env = OrtEnvironment.getEnvironment(); + String[] arrValues = new String[] {"this", "is", "a", "single", "dimensional", "string"}; + try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { + Assertions.assertArrayEquals(new long[] {6}, t.getInfo().shape); + String[] output = (String[]) t.getValue(); + Assertions.assertArrayEquals(arrValues, output); + } - String[][] stringValues = - new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}}; - try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) { - Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape); - String[][] output = (String[][]) t.getValue(); - Assertions.assertArrayEquals(stringValues, output); - } + String[][] stringValues = + new String[][] {{"this", "is", "a"}, {"multi", "dimensional", "string"}}; + try (OnnxTensor t = OnnxTensor.createTensor(env, stringValues)) { + Assertions.assertArrayEquals(new long[] {2, 3}, t.getInfo().shape); + String[][] output = (String[][]) t.getValue(); + Assertions.assertArrayEquals(stringValues, output); + } - String[][][] deepStringValues = - new String[][][] { - {{"this", "is", "a"}, {"multi", "dimensional", "string"}}, - {{"with", "lots", "more"}, {"dimensions", "than", "before"}} - }; - try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) { - Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape); - String[][][] output = (String[][][]) t.getValue(); - Assertions.assertArrayEquals(deepStringValues, output); - } + String[][][] deepStringValues = + new String[][][] { + {{"this", "is", "a"}, {"multi", "dimensional", "string"}}, + {{"with", "lots", "more"}, {"dimensions", "than", "before"}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, deepStringValues)) { + Assertions.assertArrayEquals(new long[] {2, 2, 3}, t.getInfo().shape); + String[][][] output = (String[][][]) t.getValue(); + Assertions.assertArrayEquals(deepStringValues, output); } } @Test public void testUint8Creation() throws OrtException { - try (OrtEnvironment env = OrtEnvironment.getEnvironment()) { - byte[] buf = new byte[] {0, 1}; - ByteBuffer data = ByteBuffer.wrap(buf); - long[] shape = new long[] {2}; - try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) { - Assertions.assertArrayEquals(buf, (byte[]) t.getValue()); - } + OrtEnvironment env = OrtEnvironment.getEnvironment(); + byte[] buf = new byte[] {0, 1}; + ByteBuffer data = ByteBuffer.wrap(buf); + long[] shape = new long[] {2}; + try (OnnxTensor t = OnnxTensor.createTensor(env, data, shape, OnnxJavaType.UINT8)) { + Assertions.assertArrayEquals(buf, (byte[]) t.getValue()); } } } diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index d255f04ecb66a..43fef5eb1e122 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -1,16 +1,31 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; +import java.io.BufferedInputStream; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import org.junit.jupiter.api.Assertions; /** Test helpers for manipulating primitive arrays. */ class TestHelpers { + private static final Pattern LOAD_PATTERN = Pattern.compile("[,\\[\\] ]"); + static boolean[] toPrimitiveBoolean(List input) { boolean[] output = new boolean[input.size()]; @@ -234,4 +249,179 @@ static void flattenDoubleBase(double[] input, List output) { static void flattenStringBase(String[] input, List output) { output.addAll(Arrays.asList(input)); } + + static Path getResourcePath(String path) { + return new File(InferenceTest.class.getResource(path).getFile()).toPath(); + } + + static float[] loadTensorFromFile(Path filename) { + return loadTensorFromFile(filename, true); + } + + static float[] loadTensorFromFile(Path filename, boolean skipHeader) { + // read data from file + try (BufferedReader reader = new BufferedReader(new FileReader(filename.toFile()))) { + if (skipHeader) { + reader.readLine(); // skip the input name + } + String[] dataStr = LOAD_PATTERN.split(reader.readLine()); + List tensorData = new ArrayList<>(); + for (int i = 0; i < dataStr.length; i++) { + if (!dataStr[i].isEmpty()) { + tensorData.add(Float.parseFloat(dataStr[i])); + } + } + return toPrimitiveFloat(tensorData); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private static TypeWidth getTypeAndWidth(OnnxMl.TensorProto.DataType elemType) { + OnnxJavaType type; + int width; + switch (elemType) { + case FLOAT: + type = OnnxJavaType.FLOAT; + width = 4; + break; + case UINT8: + case INT8: + type = OnnxJavaType.INT8; + width = 1; + break; + case UINT16: + case INT16: + type = OnnxJavaType.INT16; + width = 2; + break; + case INT32: + case UINT32: + type = OnnxJavaType.INT32; + width = 4; + break; + case INT64: + case UINT64: + type = OnnxJavaType.INT64; + width = 8; + break; + case STRING: + type = OnnxJavaType.STRING; + width = 1; + break; + case BOOL: + type = OnnxJavaType.BOOL; + width = 1; + break; + case FLOAT16: + type = OnnxJavaType.FLOAT; + width = 2; + break; + case DOUBLE: + type = OnnxJavaType.DOUBLE; + width = 8; + break; + default: + type = null; + width = 0; + break; + } + return new TypeWidth(type, width); + } + + static StringTensorPair loadTensorFromFilePb( + OrtEnvironment env, File filename, Map nodeMetaDict) + throws IOException, OrtException { + InputStream is = new BufferedInputStream(new FileInputStream(filename), 1024 * 1024 * 4); + OnnxMl.TensorProto tensor = OnnxMl.TensorProto.parseFrom(is); + is.close(); + + TypeWidth tw = getTypeAndWidth(OnnxMl.TensorProto.DataType.forNumber(tensor.getDataType())); + int width = tw.width; + OnnxJavaType tensorElemType = tw.type; + long[] intDims = new long[tensor.getDimsCount()]; + for (int i = 0; i < tensor.getDimsCount(); i++) { + intDims[i] = tensor.getDims(i); + } + + TensorInfo nodeMeta = null; + String nodeName = ""; + if (nodeMetaDict.size() == 1) { + for (Map.Entry e : nodeMetaDict.entrySet()) { + nodeMeta = (TensorInfo) e.getValue().getInfo(); + nodeName = e.getKey(); // valid for single node input + } + } else if (nodeMetaDict.size() > 1) { + if (!tensor.getName().isEmpty()) { + nodeMeta = (TensorInfo) nodeMetaDict.get(tensor.getName()).getInfo(); + nodeName = tensor.getName(); + } else { + boolean matchfound = false; + // try to find from matching type and shape + for (Map.Entry e : nodeMetaDict.entrySet()) { + if (e.getValue().getInfo() instanceof TensorInfo) { + TensorInfo meta = (TensorInfo) e.getValue().getInfo(); + if (tensorElemType == meta.type && tensor.getDimsCount() == meta.shape.length) { + int i = 0; + for (; i < meta.shape.length; i++) { + if (meta.shape[i] != -1 && meta.shape[i] != intDims[i]) { + break; + } + } + if (i >= meta.shape.length) { + matchfound = true; + nodeMeta = meta; + nodeName = e.getKey(); + break; + } + } + } + } + if (!matchfound) { + // throw error + throw new IllegalStateException( + "No matching Tensor found in InputOutputMetadata corresponding to the serialized tensor loaded from " + + filename); + } + } + } else { + // throw error + throw new IllegalStateException( + "While reading the serialized tensor loaded from " + + filename + + ", metaDataDict has 0 elements"); + } + + Assertions.assertEquals(tensorElemType, nodeMeta.type); + Assertions.assertEquals(nodeMeta.shape.length, tensor.getDimsCount()); + for (int i = 0; i < nodeMeta.shape.length; i++) { + Assertions.assertTrue((nodeMeta.shape[i] == -1) || (nodeMeta.shape[i] == intDims[i])); + } + + ByteBuffer buffer = ByteBuffer.wrap(tensor.getRawData().toByteArray()); + + OnnxTensor onnxTensor = OnnxTensor.createTensor(env, buffer, intDims, tensorElemType); + + return new StringTensorPair(nodeName, onnxTensor); + } + + private static class TypeWidth { + public final OnnxJavaType type; + public final int width; + + public TypeWidth(OnnxJavaType type, int width) { + this.type = type; + this.width = width; + } + } + + static class StringTensorPair { + public final String string; + public final OnnxTensor tensor; + + public StringTensorPair(String string, OnnxTensor tensor) { + this.string = string; + this.tensor = tensor; + } + } } diff --git a/java/src/test/java/sample/ScoreMNIST.java b/java/src/test/java/sample/ScoreMNIST.java index 5ad40a5bec4ed..5587b58e17f52 100644 --- a/java/src/test/java/sample/ScoreMNIST.java +++ b/java/src/test/java/sample/ScoreMNIST.java @@ -263,8 +263,8 @@ public static void main(String[] args) throws OrtException, IOException { return; } - try (OrtEnvironment env = OrtEnvironment.getEnvironment(); - OrtSession.SessionOptions opts = new SessionOptions()) { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + try (OrtSession.SessionOptions opts = new SessionOptions()) { opts.setOptimizationLevel(OptLevel.BASIC_OPT);