Skip to content

Commit

Permalink
[java] Changes OrtEnvironment so it can't be closed by users (#10670)
Browse files Browse the repository at this point in the history
* Changes OrtEnvironment so it can't be closed by users.

* Fix the formatting and add a same instance check.
  • Loading branch information
Craigacp authored Mar 1, 2022
1 parent e23a224 commit f856608
Show file tree
Hide file tree
Showing 8 changed files with 643 additions and 668 deletions.
1 change: 1 addition & 0 deletions java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ test {
java {
dependsOn spotlessJava
}
forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool
useJUnitPlatform()
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
Expand Down
16 changes: 8 additions & 8 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, Object data) throws Or
*/
static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info = TensorInfo.constructFromJavaArray(data);
if (info.type == OnnxJavaType.STRING) {
if (info.shape.length == 0) {
Expand Down Expand Up @@ -403,7 +403,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, String[] data, long[]
*/
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info =
new TensorInfo(
shape,
Expand Down Expand Up @@ -451,7 +451,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, FloatBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.FLOAT;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -492,7 +492,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, DoubleBuffer data, lon
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.DOUBLE;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -571,7 +571,7 @@ public static OnnxTensor createTensor(
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
Expand Down Expand Up @@ -611,7 +611,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT16;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -652,7 +652,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, IntBuffer data, long[]
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT32;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -693,7 +693,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, LongBuffer data, long[
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT64;
return createTensor(type, allocator, data, shape);
} else {
Expand Down
95 changes: 44 additions & 51 deletions java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
/*
* Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import ai.onnxruntime.OrtSession.SessionOptions;
import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
* The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate
* specific models.
*
* <p>There can be at most one OrtEnvironment object created in a JVM lifetime. This class
* implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but
* the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered
* on construction.
*/
public class OrtEnvironment implements AutoCloseable {
public final class OrtEnvironment implements AutoCloseable {

private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());

Expand All @@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable {

private static volatile OrtEnvironment INSTANCE;

private static final AtomicInteger refCount = new AtomicInteger();

private static volatile OrtLoggingLevel curLogLevel;

private static volatile String curLoggingName;
Expand All @@ -47,8 +49,6 @@ public static synchronized OrtEnvironment getEnvironment() {
// If there's no instance, create one.
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
} else {
// else return the current one.
refCount.incrementAndGet();
return INSTANCE;
}
}
Expand Down Expand Up @@ -106,7 +106,6 @@ public static synchronized OrtEnvironment getEnvironment(
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
}
}
refCount.incrementAndGet();
return INSTANCE;
}

Expand All @@ -131,7 +130,6 @@ public static synchronized OrtEnvironment getEnvironment(
} catch (OrtException e) {
throw new IllegalStateException("Failed to create OrtEnvironment", e);
}
refCount.incrementAndGet();
return INSTANCE;
} else {
// As the thread pool state is unknown, and that's probably not what the user wanted.
Expand All @@ -144,8 +142,6 @@ public static synchronized OrtEnvironment getEnvironment(

final OrtAllocator defaultAllocator;

private volatile boolean closed = false;

/**
* Create an OrtEnvironment using a default name.
*
Expand All @@ -165,6 +161,8 @@ private OrtEnvironment() throws OrtException {
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}

/**
Expand All @@ -181,6 +179,8 @@ private OrtEnvironment(OrtLoggingLevel loggingLevel, String name, ThreadingOptio
createHandle(
OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}

/**
Expand Down Expand Up @@ -219,11 +219,7 @@ public OrtSession createSession(String modelPath, SessionOptions options) throws
*/
OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelPath, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelPath, allocator, options);
}

/**
Expand Down Expand Up @@ -262,11 +258,7 @@ public OrtSession createSession(byte[] modelArray) throws OrtException {
*/
OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelArray, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelArray, allocator, options);
}

/**
Expand All @@ -279,41 +271,11 @@ public void setTelemetry(boolean sendTelemetry) throws OrtException {
setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry);
}

/**
* Is this environment closed?
*
* @return True if the environment is closed.
*/
public boolean isClosed() {
return closed;
}

@Override
public String toString() {
return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")";
}

/**
* Closes the OrtEnvironment. If this is the last reference to the environment then it closes the
* native handle.
*
* @throws OrtException If the close failed.
*/
@Override
public synchronized void close() throws OrtException {
synchronized (refCount) {
int curCount = refCount.get();
if (curCount != 0) {
refCount.decrementAndGet();
}
if (curCount == 1) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
INSTANCE = null;
}
}
}

/**
* Gets the providers available in this environment.
*
Expand Down Expand Up @@ -377,11 +339,22 @@ private static native long createHandle(
private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry)
throws OrtException;

/** Close is a no-op on OrtEnvironment since ORT 1.11. */
@Override
public void close() {}

/**
* Controls the global thread pools in the environment. Only used if the session is constructed
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
*/
public static final class ThreadingOptions implements AutoCloseable {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}

private final long nativeHandle;

Expand Down Expand Up @@ -486,4 +459,24 @@ private native void setGlobalDenormalAsZero(long apiHandle, long nativeHandle)

private native void closeThreadingOptions(long apiHandle, long nativeHandle);
}

private static final class OrtEnvCloser implements Runnable {

private final long apiHandle;
private final long nativeHandle;

OrtEnvCloser(long apiHandle, long nativeHandle) {
this.apiHandle = apiHandle;
this.nativeHandle = nativeHandle;
}

@Override
public void run() {
try {
OrtEnvironment.close(apiHandle, nativeHandle);
} catch (OrtException e) {
System.err.println("Error closing OrtEnvironment, " + e);
}
}
}
}
88 changes: 88 additions & 0 deletions java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import static ai.onnxruntime.TestHelpers.getResourcePath;
import static ai.onnxruntime.TestHelpers.loadTensorFromFile;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;

/** This test is in a separate class to ensure it is run in a clean JVM. */
public class EnvironmentThreadPoolTest {

@Test
public void environmentThreadPoolTest() throws OrtException {
Path squeezeNet = getResourcePath("/squeezenet.onnx");
String modelPath = squeezeNet.toString();
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
Map<String, OnnxTensor> container = new HashMap<>();

OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
threadOpts.setGlobalInterOpNumThreads(2);
threadOpts.setGlobalIntraOpNumThreads(2);
threadOpts.setGlobalDenormalAsZero();
threadOpts.setGlobalSpinControl(true);
OrtEnvironment env =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) {
disableThreadOptions.disablePerSessionThreads();

// Check that the regular session executes
try (OrtSession session = env.createSession(modelPath, options)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}

// Check that the session using the env thread pool executes
try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}
}

try {
OrtEnvironment newEnv =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
// fail as we can't recreate environments with different threading options
fail("Should have thrown IllegalStateException");
} catch (IllegalStateException e) {
// pass
}

threadOpts.close();
}
}
Loading

0 comments on commit f856608

Please sign in to comment.