Skip to content

Commit

Permalink
add available providers method
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzawa-san committed Nov 19, 2022
1 parent 7af46e9 commit 541ff4d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Api.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package com.jyuzawa.onnxruntime;

import java.util.Set;

/**
* The top-level API of the ONNX runtime.
*
Expand All @@ -15,4 +17,6 @@ public interface Api {
* @return a builder
*/
Environment.Builder newEnvironment();

Set<String> getAvailableProviders();
}
32 changes: 32 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import com.jyuzawa.onnxruntime_extern.OrtApi.EnableTelemetryEvents;
import com.jyuzawa.onnxruntime_extern.OrtApi.FillStringTensor;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetAllocatorWithDefaultOptions;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetAvailableProviders;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetDimensions;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetDimensionsCount;
import com.jyuzawa.onnxruntime_extern.OrtApi.GetErrorCode;
Expand All @@ -60,6 +61,7 @@
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataGetProducerName;
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataGetVersion;
import com.jyuzawa.onnxruntime_extern.OrtApi.ModelMetadataLookupCustomMetadataMap;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseAvailableProviders;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseEnv;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseMemoryInfo;
import com.jyuzawa.onnxruntime_extern.OrtApi.ReleaseModelMetadata;
Expand Down Expand Up @@ -103,6 +105,9 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.function.Function;

final class ApiImpl implements Api {
Expand Down Expand Up @@ -134,6 +139,7 @@ final class ApiImpl implements Api {
final EnableTelemetryEvents EnableTelemetryEvents;
final FillStringTensor FillStringTensor;
final GetAllocatorWithDefaultOptions GetAllocatorWithDefaultOptions;
final GetAvailableProviders GetAvailableProviders;
final GetDimensions GetDimensions;
final GetDimensionsCount GetDimensionsCount;
final GetErrorCode GetErrorCode;
Expand All @@ -158,6 +164,7 @@ final class ApiImpl implements Api {
final ModelMetadataGetProducerName ModelMetadataGetProducerName;
final ModelMetadataGetVersion ModelMetadataGetVersion;
final ModelMetadataLookupCustomMetadataMap ModelMetadataLookupCustomMetadataMap;
final ReleaseAvailableProviders ReleaseAvailableProviders;
final ReleaseEnv ReleaseEnv;
final ReleaseMemoryInfo ReleaseMemoryInfo;
final ReleaseModelMetadata ReleaseModelMetadata;
Expand Down Expand Up @@ -197,6 +204,8 @@ final class ApiImpl implements Api {
final SessionGetOverridableInitializerName SessionGetOverridableInitializerName;
final SessionGetOverridableInitializerTypeInfo SessionGetOverridableInitializerTypeInfo;

private final Set<String> providers;

ApiImpl(MemorySegment segment) {
MemorySession scope = MemorySession.global();
this.AddRunConfigEntry = OrtApi.AddRunConfigEntry(segment, scope);
Expand Down Expand Up @@ -227,6 +236,7 @@ final class ApiImpl implements Api {
this.EnableProfiling = OrtApi.EnableProfiling(segment, scope);
this.FillStringTensor = OrtApi.FillStringTensor(segment, scope);
this.GetAllocatorWithDefaultOptions = OrtApi.GetAllocatorWithDefaultOptions(segment, scope);
this.GetAvailableProviders = OrtApi.GetAvailableProviders(segment, scope);
this.GetDimensions = OrtApi.GetDimensions(segment, scope);
this.GetDimensionsCount = OrtApi.GetDimensionsCount(segment, scope);
this.GetErrorCode = OrtApi.GetErrorCode(segment, scope);
Expand All @@ -251,6 +261,7 @@ final class ApiImpl implements Api {
this.ModelMetadataGetProducerName = OrtApi.ModelMetadataGetProducerName(segment, scope);
this.ModelMetadataGetVersion = OrtApi.ModelMetadataGetVersion(segment, scope);
this.ModelMetadataLookupCustomMetadataMap = OrtApi.ModelMetadataLookupCustomMetadataMap(segment, scope);
this.ReleaseAvailableProviders = OrtApi.ReleaseAvailableProviders(segment, scope);
this.ReleaseEnv = OrtApi.ReleaseEnv(segment, scope);
this.ReleaseMemoryInfo = OrtApi.ReleaseMemoryInfo(segment, scope);
this.ReleaseModelMetadata = OrtApi.ReleaseModelMetadata(segment, scope);
Expand Down Expand Up @@ -289,13 +300,34 @@ final class ApiImpl implements Api {
this.SessionGetOverridableInitializerCount = OrtApi.SessionGetOverridableInitializerCount(segment, scope);
this.SessionGetOverridableInitializerName = OrtApi.SessionGetOverridableInitializerName(segment, scope);
this.SessionGetOverridableInitializerTypeInfo = OrtApi.SessionGetOverridableInitializerTypeInfo(segment, scope);

try (MemorySession session = MemorySession.openConfined()) {
Set<String> providers = new LinkedHashSet<>();
MemorySegment pointer = scope.allocate(C_POINTER);
MemorySegment countPointer = scope.allocate(C_INT);
checkStatus(GetAvailableProviders.apply(pointer.address(), countPointer.address()));
int numProviders = countPointer.getAtIndex(C_INT, 0);
MemorySegment providersArray = MemorySegment.ofAddress(
pointer.getAtIndex(C_POINTER, 0), numProviders * C_POINTER.byteSize(), session);
for (int i = 0; i < numProviders; i++) {
MemoryAddress providerAddress = providersArray.getAtIndex(C_POINTER, i);
providers.add(providerAddress.getUtf8String(0));
}
checkStatus(ReleaseAvailableProviders.apply(providersArray.address(), numProviders));
this.providers = Collections.unmodifiableSet(providers);
}
}

@Override
public Environment.Builder newEnvironment() {
return new EnvironmentBuilderImpl(this);
}

@Override
public Set<String> getAvailableProviders() {
return providers;
}

void checkStatus(Addressable rawAddress) {
MemoryAddress status = rawAddress.address();
if (MemoryAddress.NULL.equals(status)) {
Expand Down
11 changes: 10 additions & 1 deletion src/test/java/com/jyuzawa/onnxruntime/SessionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.Collections;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import onnx.OnnxMl.AttributeProto;
import onnx.OnnxMl.AttributeProto.AttributeType;
import onnx.OnnxMl.GraphProto;
Expand All @@ -40,12 +41,13 @@
public class SessionTest {

private static final Charset UTF8 = Charset.forName("utf-8");
private static Api api;
private static Environment environment;

@BeforeClass
public static void setup() {
OnnxRuntime apiBase = OnnxRuntime.get();
Api api = apiBase.getApi();
api = apiBase.getApi();
environment = api.newEnvironment()
.setLogSeverityLevel(OnnxRuntimeLoggingLevel.VERBOSE)
.setLogId("testing")
Expand Down Expand Up @@ -73,6 +75,13 @@ private ByteBuffer identityModel(TypeProto type) {
.asReadOnlyByteBuffer();
}

@Test
public void providersTest() {
Set<String> providers = api.getAvailableProviders();
assertFalse(providers.isEmpty());
assertTrue(providers.contains("CPUExecutionProvider"));
}

@Test
public void infoTest() throws IOException {
TypeProto type = TypeProto.newBuilder()
Expand Down

0 comments on commit 541ff4d

Please sign in to comment.