Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add available providers method #68

Merged
merged 1 commit into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Api.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package com.jyuzawa.onnxruntime;

import java.util.Set;

/**
* The top-level API of the ONNX runtime.
*
Expand All @@ -15,4 +17,6 @@ public interface Api {
* @return a builder
*/
Environment.Builder newEnvironment();

Set<String> getAvailableProviders();
}
32 changes: 32 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.jyuzawa.onnxruntime_extern.OrtApi.EnableTelemetryEvents;
import com.jyuzawa.onnxruntime_extern.OrtApi.FillStringTensor;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetAllocatorWithDefaultOptions;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetAvailableProviders;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetDimensions;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetDimensionsCount;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetErrorCode;
Expand All @@ -60,6 +61,7 @@
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataGetProducerName;
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataGetVersion;
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataLookupCustomMetadataMap;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseAvailableProviders;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseEnv;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseMemoryInfo;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseModelMetadata;
Expand Down Expand Up @@ -103,6 +105,9 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.Function;

final class ApiImpl implements Api {
Expand Down Expand Up @@ -134,6 +139,7 @@ final class ApiImpl implements Api {
final EnableTelemetryEvents EnableTelemetryEvents;
final FillStringTensor FillStringTensor;
final GetAllocatorWithDefaultOptions GetAllocatorWithDefaultOptions;
final GetAvailableProviders GetAvailableProviders;
final GetDimensions GetDimensions;
final GetDimensionsCount GetDimensionsCount;
final GetErrorCode GetErrorCode;
Expand All @@ -158,6 +164,7 @@ final class ApiImpl implements Api {
final ModelMetadataGetProducerName ModelMetadataGetProducerName;
final ModelMetadataGetVersion ModelMetadataGetVersion;
final ModelMetadataLookupCustomMetadataMap ModelMetadataLookupCustomMetadataMap;
final ReleaseAvailableProviders ReleaseAvailableProviders;
final ReleaseEnv ReleaseEnv;
final ReleaseMemoryInfo ReleaseMemoryInfo;
final ReleaseModelMetadata ReleaseModelMetadata;
Expand Down Expand Up @@ -197,6 +204,8 @@ final class ApiImpl implements Api {
final SessionGetOverridableInitializerName SessionGetOverridableInitializerName;
final SessionGetOverridableInitializerTypeInfo SessionGetOverridableInitializerTypeInfo;

private final Set<String> providers;

ApiImpl(MemorySegment segment) {
MemorySession scope = MemorySession.global();
this.AddRunConfigEntry = OrtApi.AddRunConfigEntry(segment, scope);
Expand Down Expand Up @@ -227,6 +236,7 @@ final class ApiImpl implements Api {
this.EnableProfiling = OrtApi.EnableProfiling(segment, scope);
this.FillStringTensor = OrtApi.FillStringTensor(segment, scope);
this.GetAllocatorWithDefaultOptions = OrtApi.GetAllocatorWithDefaultOptions(segment, scope);
this.GetAvailableProviders = OrtApi.GetAvailableProviders(segment, scope);
this.GetDimensions = OrtApi.GetDimensions(segment, scope);
this.GetDimensionsCount = OrtApi.GetDimensionsCount(segment, scope);
this.GetErrorCode = OrtApi.GetErrorCode(segment, scope);
Expand All @@ -251,6 +261,7 @@ final class ApiImpl implements Api {
this.ModelMetadataGetProducerName = OrtApi.ModelMetadataGetProducerName(segment, scope);
this.ModelMetadataGetVersion = OrtApi.ModelMetadataGetVersion(segment, scope);
this.ModelMetadataLookupCustomMetadataMap = OrtApi.ModelMetadataLookupCustomMetadataMap(segment, scope);
this.ReleaseAvailableProviders = OrtApi.ReleaseAvailableProviders(segment, scope);
this.ReleaseEnv = OrtApi.ReleaseEnv(segment, scope);
this.ReleaseMemoryInfo = OrtApi.ReleaseMemoryInfo(segment, scope);
this.ReleaseModelMetadata = OrtApi.ReleaseModelMetadata(segment, scope);
Expand Down Expand Up @@ -289,13 +300,34 @@ final class ApiImpl implements Api {
this.SessionGetOverridableInitializerCount = OrtApi.SessionGetOverridableInitializerCount(segment, scope);
this.SessionGetOverridableInitializerName = OrtApi.SessionGetOverridableInitializerName(segment, scope);
this.SessionGetOverridableInitializerTypeInfo = OrtApi.SessionGetOverridableInitializerTypeInfo(segment, scope);

try (MemorySession session = MemorySession.openConfined()) {
Set<String> providers = new LinkedHashSet<>();
MemorySegment pointer = scope.allocate(C_POINTER);
MemorySegment countPointer = scope.allocate(C_INT);
checkStatus(GetAvailableProviders.apply(pointer.address(), countPointer.address()));
int numProviders = countPointer.getAtIndex(C_INT, 0);
MemorySegment providersArray = MemorySegment.ofAddress(
pointer.getAtIndex(C_POINTER, 0), numProviders * C_POINTER.byteSize(), session);
for (int i = 0; i < numProviders; i++) {
MemoryAddress providerAddress = providersArray.getAtIndex(C_POINTER, i);
providers.add(providerAddress.getUtf8String(0));
}
checkStatus(ReleaseAvailableProviders.apply(providersArray.address(), numProviders));
this.providers = Collections.unmodifiableSet(providers);
}
}

@Override
public Environment.Builder newEnvironment() {
return new EnvironmentBuilderImpl(this);
}

@Override
public Set<String> getAvailableProviders() {
return providers;
}

void checkStatus(Addressable rawAddress) {
MemoryAddress status = rawAddress.address();
if (MemoryAddress.NULL.equals(status)) {
Expand Down
11 changes: 10 additions & 1 deletion src/test/java/com/jyuzawa/onnxruntime/SessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Collections;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import onnx.OnnxMl.AttributeProto;
import onnx.OnnxMl.AttributeProto.AttributeType;
import onnx.OnnxMl.GraphProto;
Expand All @@ -40,12 +41,13 @@
public class SessionTest {

private static final Charset UTF8 = Charset.forName("utf-8");
private static Api api;
private static Environment environment;

@BeforeClass
public static void setup() {
OnnxRuntime apiBase = OnnxRuntime.get();
Api api = apiBase.getApi();
api = apiBase.getApi();
environment = api.newEnvironment()
.setLogSeverityLevel(OnnxRuntimeLoggingLevel.VERBOSE)
.setLogId("testing")
Expand Down Expand Up @@ -73,6 +75,13 @@ private ByteBuffer identityModel(TypeProto type) {
.asReadOnlyByteBuffer();
}

@Test
public void providersTest() {
Set<String> providers = api.getAvailableProviders();
assertFalse(providers.isEmpty());
assertTrue(providers.contains("CPUExecutionProvider"));
}

@Test
public void infoTest() throws IOException {
TypeProto type = TypeProto.newBuilder()
Expand Down