Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java] Changes OrtEnvironment so it can't be closed by users #10670

Merged
merged 2 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ test {
java {
dependsOn spotlessJava
}
forkEvery 1 // Forces each test class to be run in a separate JVM, which is necessary for testing the environment thread pool
useJUnitPlatform()
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
Expand Down
16 changes: 8 additions & 8 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, Object data) throws Or
*/
static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Object data)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info = TensorInfo.constructFromJavaArray(data);
if (info.type == OnnxJavaType.STRING) {
if (info.shape.length == 0) {
Expand Down Expand Up @@ -403,7 +403,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, String[] data, long[]
*/
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, String[] data, long[] shape) throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
TensorInfo info =
new TensorInfo(
shape,
Expand Down Expand Up @@ -451,7 +451,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, FloatBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, FloatBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.FLOAT;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -492,7 +492,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, DoubleBuffer data, lon
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, DoubleBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.DOUBLE;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -571,7 +571,7 @@ public static OnnxTensor createTensor(
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ByteBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
return createTensor(type, allocator, data, shape);
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
Expand Down Expand Up @@ -611,7 +611,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT16;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -652,7 +652,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, IntBuffer data, long[]
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, IntBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT32;
return createTensor(type, allocator, data, shape);
} else {
Expand Down Expand Up @@ -693,7 +693,7 @@ public static OnnxTensor createTensor(OrtEnvironment env, LongBuffer data, long[
static OnnxTensor createTensor(
OrtEnvironment env, OrtAllocator allocator, LongBuffer data, long[] shape)
throws OrtException {
if ((!env.isClosed()) && (!allocator.isClosed())) {
if (!allocator.isClosed()) {
OnnxJavaType type = OnnxJavaType.INT64;
return createTensor(type, allocator, data, shape);
} else {
Expand Down
95 changes: 44 additions & 51 deletions java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
/*
* Copyright (c) 2019-2021 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import ai.onnxruntime.OrtSession.SessionOptions;
import java.io.IOException;
import java.util.EnumSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.logging.Logger;

/**
* The host object for the onnx-runtime system. Can create {@link OrtSession}s which encapsulate
* specific models.
*
* <p>There can be at most one OrtEnvironment object created in a JVM lifetime. This class
* implements {@link AutoCloseable} as before for backwards compatibility with 1.10 and earlier, but
* the {@link #close} method is a no-op. The environment is closed by a JVM shutdown hook registered
* on construction.
*/
public class OrtEnvironment implements AutoCloseable {
public final class OrtEnvironment implements AutoCloseable {

private static final Logger logger = Logger.getLogger(OrtEnvironment.class.getName());

Expand All @@ -30,8 +34,6 @@ public class OrtEnvironment implements AutoCloseable {

private static volatile OrtEnvironment INSTANCE;

private static final AtomicInteger refCount = new AtomicInteger();

private static volatile OrtLoggingLevel curLogLevel;

private static volatile String curLoggingName;
Expand All @@ -47,8 +49,6 @@ public static synchronized OrtEnvironment getEnvironment() {
// If there's no instance, create one.
return getEnvironment(OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING, DEFAULT_NAME);
} else {
// else return the current one.
refCount.incrementAndGet();
return INSTANCE;
}
}
Expand Down Expand Up @@ -106,7 +106,6 @@ public static synchronized OrtEnvironment getEnvironment(
"Tried to change OrtEnvironment's logging level or name while a reference exists.");
}
}
refCount.incrementAndGet();
return INSTANCE;
}

Expand All @@ -131,7 +130,6 @@ public static synchronized OrtEnvironment getEnvironment(
} catch (OrtException e) {
throw new IllegalStateException("Failed to create OrtEnvironment", e);
}
refCount.incrementAndGet();
return INSTANCE;
} else {
// As the thread pool state is unknown, and that's probably not what the user wanted.
Expand All @@ -144,8 +142,6 @@ public static synchronized OrtEnvironment getEnvironment(

final OrtAllocator defaultAllocator;

private volatile boolean closed = false;

/**
* Create an OrtEnvironment using a default name.
*
Expand All @@ -165,6 +161,8 @@ private OrtEnvironment() throws OrtException {
private OrtEnvironment(OrtLoggingLevel loggingLevel, String name) throws OrtException {
nativeHandle = createHandle(OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}

/**
Expand All @@ -181,6 +179,8 @@ private OrtEnvironment(OrtLoggingLevel loggingLevel, String name, ThreadingOptio
createHandle(
OnnxRuntime.ortApiHandle, loggingLevel.getValue(), name, threadOptions.nativeHandle);
defaultAllocator = new OrtAllocator(getDefaultAllocator(OnnxRuntime.ortApiHandle), true);
Runtime.getRuntime()
.addShutdownHook(new Thread(new OrtEnvCloser(OnnxRuntime.ortApiHandle, nativeHandle)));
}

/**
Expand Down Expand Up @@ -219,11 +219,7 @@ public OrtSession createSession(String modelPath, SessionOptions options) throws
*/
OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelPath, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelPath, allocator, options);
}

/**
Expand Down Expand Up @@ -262,11 +258,7 @@ public OrtSession createSession(byte[] modelArray) throws OrtException {
*/
OrtSession createSession(byte[] modelArray, OrtAllocator allocator, SessionOptions options)
throws OrtException {
if (!closed) {
return new OrtSession(this, modelArray, allocator, options);
} else {
throw new IllegalStateException("Trying to create an OrtSession on a closed OrtEnvironment.");
}
return new OrtSession(this, modelArray, allocator, options);
}

/**
Expand All @@ -279,41 +271,11 @@ public void setTelemetry(boolean sendTelemetry) throws OrtException {
setTelemetry(OnnxRuntime.ortApiHandle, nativeHandle, sendTelemetry);
}

/**
* Is this environment closed?
*
* @return True if the environment is closed.
*/
public boolean isClosed() {
return closed;
}

@Override
public String toString() {
return "OrtEnvironment(name=" + curLoggingName + ",logLevel=" + curLogLevel + ")";
}

/**
* Closes the OrtEnvironment. If this is the last reference to the environment then it closes the
* native handle.
*
* @throws OrtException If the close failed.
*/
@Override
public synchronized void close() throws OrtException {
synchronized (refCount) {
int curCount = refCount.get();
if (curCount != 0) {
refCount.decrementAndGet();
}
if (curCount == 1) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
INSTANCE = null;
}
}
}

/**
* Gets the providers available in this environment.
*
Expand Down Expand Up @@ -377,11 +339,22 @@ private static native long createHandle(
private static native void setTelemetry(long apiHandle, long nativeHandle, boolean sendTelemetry)
throws OrtException;

/** Close is a no-op on OrtEnvironment since ORT 1.11. */
@Override
public void close() {}

/**
* Controls the global thread pools in the environment. Only used if the session is constructed
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
*/
public static final class ThreadingOptions implements AutoCloseable {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}

private final long nativeHandle;

Expand Down Expand Up @@ -486,4 +459,24 @@ private native void setGlobalDenormalAsZero(long apiHandle, long nativeHandle)

private native void closeThreadingOptions(long apiHandle, long nativeHandle);
}

private static final class OrtEnvCloser implements Runnable {

private final long apiHandle;
private final long nativeHandle;

OrtEnvCloser(long apiHandle, long nativeHandle) {
this.apiHandle = apiHandle;
this.nativeHandle = nativeHandle;
}

@Override
public void run() {
try {
OrtEnvironment.close(apiHandle, nativeHandle);
} catch (OrtException e) {
System.err.println("Error closing OrtEnvironment, " + e);
}
}
}
}
88 changes: 88 additions & 0 deletions java/src/test/java/ai/onnxruntime/EnvironmentThreadPoolTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import static ai.onnxruntime.TestHelpers.getResourcePath;
import static ai.onnxruntime.TestHelpers.loadTensorFromFile;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;

/** This test is in a separate class to ensure it is run in a clean JVM. */
public class EnvironmentThreadPoolTest {

@Test
public void environmentThreadPoolTest() throws OrtException {
Path squeezeNet = getResourcePath("/squeezenet.onnx");
String modelPath = squeezeNet.toString();
float[] inputData = loadTensorFromFile(getResourcePath("/bench.in"));
float[] expectedOutput = loadTensorFromFile(getResourcePath("/bench.expected_out"));
Map<String, OnnxTensor> container = new HashMap<>();

OrtEnvironment.ThreadingOptions threadOpts = new OrtEnvironment.ThreadingOptions();
threadOpts.setGlobalInterOpNumThreads(2);
threadOpts.setGlobalIntraOpNumThreads(2);
threadOpts.setGlobalDenormalAsZero();
threadOpts.setGlobalSpinControl(true);
OrtEnvironment env =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "environmentThreadPoolTest", threadOpts);
try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
OrtSession.SessionOptions disableThreadOptions = new OrtSession.SessionOptions()) {
disableThreadOptions.disablePerSessionThreads();

// Check that the regular session executes
try (OrtSession session = env.createSession(modelPath, options)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}

// Check that the session using the env thread pool executes
try (OrtSession session = env.createSession(modelPath, disableThreadOptions)) {
NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape;
Object tensorData = OrtUtil.reshape(inputData, inputShape);
OnnxTensor tensor = OnnxTensor.createTensor(env, tensorData);
container.put(inputMeta.getName(), tensor);
try (OrtSession.Result result = session.run(container)) {
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
assertArrayEquals(expectedOutput, resultArray, 1e-6f);
}
container.clear();
tensor.close();
}
}

try {
OrtEnvironment newEnv =
OrtEnvironment.getEnvironment(
OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL, "fail", threadOpts);
// fail as we can't recreate environments with different threading options
fail("Should have thrown IllegalStateException");
} catch (IllegalStateException e) {
// pass
}

threadOpts.close();
}
}
Loading