Skip to content

Commit

Permalink
[java] Adding the graph description to the exposed model metadata. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp authored Feb 28, 2022
1 parent 037f08f commit e47434e
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 6 deletions.
28 changes: 27 additions & 1 deletion java/src/main/java/ai/onnxruntime/OnnxModelMetadata.java
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.util.Collections;
Expand All @@ -17,6 +21,7 @@ public final class OnnxModelMetadata {

private final String producerName;
private final String graphName;
private final String graphDescription;
private final String domain;
private final String description;
private final long version;
Expand All @@ -29,6 +34,7 @@ public final class OnnxModelMetadata {
*
* @param producerName The model producer name.
* @param graphName The model graph name.
* @param graphDescription The model graph description.
* @param domain The model domain name.
* @param description The model description.
* @param version The model version.
Expand All @@ -37,12 +43,14 @@ public final class OnnxModelMetadata {
OnnxModelMetadata(
String producerName,
String graphName,
String graphDescription,
String domain,
String description,
long version,
String[] customMetadataArray) {
this.producerName = producerName == null ? "" : producerName;
this.graphName = graphName == null ? "" : graphName;
this.graphDescription = graphDescription == null ? "" : graphDescription;
this.domain = domain == null ? "" : domain;
this.description = description == null ? "" : description;
this.version = version;
Expand All @@ -66,6 +74,7 @@ public final class OnnxModelMetadata {
*
* @param producerName The model producer name.
* @param graphName The model graph name.
* @param graphDescription The model graph name.
* @param domain The model domain name.
* @param description The model description.
* @param version The model version.
Expand All @@ -74,12 +83,14 @@ public final class OnnxModelMetadata {
OnnxModelMetadata(
String producerName,
String graphName,
String graphDescription,
String domain,
String description,
long version,
Map<String, String> customMetadata) {
this.producerName = producerName == null ? "" : producerName;
this.graphName = graphName == null ? "" : graphName;
this.graphDescription = graphDescription == null ? "" : graphDescription;
this.domain = domain == null ? "" : domain;
this.description = description == null ? "" : description;
this.version = version;
Expand All @@ -94,6 +105,7 @@ public final class OnnxModelMetadata {
public OnnxModelMetadata(OnnxModelMetadata other) {
this.producerName = other.producerName;
this.graphName = other.graphName;
this.graphDescription = other.graphDescription;
this.domain = other.domain;
this.description = other.description;
this.version = other.version;
Expand All @@ -111,14 +123,16 @@ public boolean equals(Object o) {
return version == that.version
&& producerName.equals(that.producerName)
&& graphName.equals(that.graphName)
&& graphDescription.equals(that.graphDescription)
&& domain.equals(that.domain)
&& description.equals(that.description)
&& customMetadata.equals(that.customMetadata);
}

@Override
public int hashCode() {
return Objects.hash(producerName, graphName, domain, description, version, customMetadata);
return Objects.hash(
producerName, graphName, graphDescription, domain, description, version, customMetadata);
}

/**
Expand All @@ -139,6 +153,15 @@ public String getGraphName() {
return graphName;
}

/**
* Gets the graph description.
*
* @return The graph description.
*/
public String getGraphDescription() {
return graphDescription;
}

/**
* Gets the domain.
*
Expand Down Expand Up @@ -195,6 +218,9 @@ public String toString() {
+ ", graphName='"
+ graphName
+ '\''
+ ", graphDescription='"
+ graphDescription
+ '\''
+ ", domain='"
+ domain
+ '\''
Expand Down
13 changes: 9 additions & 4 deletions java/src/main/native/ai_onnxruntime_OrtSession.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand Down Expand Up @@ -385,7 +385,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName);
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "<init>",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V");

// Get metadata
OrtModelMetadata* metadata;
Expand All @@ -402,6 +402,11 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));

// Read out the graph description and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer));
jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer));

// Read out the domain and convert it to a java.lang.String
checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer));
jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer);
Expand Down Expand Up @@ -449,8 +454,8 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata
}

// Invoke the metadata constructor
//OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray)
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, domainStr, descriptionStr, (jlong) version, customArray);
//OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray)
jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray);

// Release the metadata
api->ReleaseModelMetadata(metadata);
Expand Down
32 changes: 31 additions & 1 deletion java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -1168,6 +1168,36 @@ public void testLoadCustomLibrary() throws OrtException {
}
}

@Test
public void testModelMetadata() throws OrtException {
String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString();

try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) {
try (OrtSession session = env.createSession(modelPath)) {
OnnxModelMetadata modelMetadata = session.getMetadata();

Assertions.assertEquals(1, modelMetadata.getVersion());

Assertions.assertEquals("Hari", modelMetadata.getProducerName());

Assertions.assertEquals("matmul test", modelMetadata.getGraphName());

Assertions.assertEquals("", modelMetadata.getDomain());

Assertions.assertEquals(
"This is a test model with a valid ORT config Json", modelMetadata.getDescription());

Assertions.assertEquals("graph description", modelMetadata.getGraphDescription());

Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size());
Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key"));
Assertions.assertEquals(
"{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}",
modelMetadata.getCustomMetadata().get("ort_config"));
}
}
}

@Test
public void testModelInputBOOL() throws OrtException {
// model takes 1x5 input of fixed type, echoes back
Expand Down

0 comments on commit e47434e

Please sign in to comment.