From e47434ea1281d83e54ea99ef053223aae874ec91 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 28 Feb 2022 13:05:03 -0500 Subject: [PATCH] [java] Adding the graph description to the exposed model metadata. (#10318) --- .../ai/onnxruntime/OnnxModelMetadata.java | 28 +++++++++++++++- .../main/native/ai_onnxruntime_OrtSession.c | 13 +++++--- .../java/ai/onnxruntime/InferenceTest.java | 32 ++++++++++++++++++- 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxModelMetadata.java b/java/src/main/java/ai/onnxruntime/OnnxModelMetadata.java index eea45717c722e..3a06f4751cc55 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxModelMetadata.java +++ b/java/src/main/java/ai/onnxruntime/OnnxModelMetadata.java @@ -1,3 +1,7 @@ +/* + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ package ai.onnxruntime; import java.util.Collections; @@ -17,6 +21,7 @@ public final class OnnxModelMetadata { private final String producerName; private final String graphName; + private final String graphDescription; private final String domain; private final String description; private final long version; @@ -29,6 +34,7 @@ public final class OnnxModelMetadata { * * @param producerName The model producer name. * @param graphName The model graph name. + * @param graphDescription The model graph description. * @param domain The model domain name. * @param description The model description. * @param version The model version. @@ -37,12 +43,14 @@ public final class OnnxModelMetadata { OnnxModelMetadata( String producerName, String graphName, + String graphDescription, String domain, String description, long version, String[] customMetadataArray) { this.producerName = producerName == null ? "" : producerName; this.graphName = graphName == null ? "" : graphName; + this.graphDescription = graphDescription == null ? "" : graphDescription; this.domain = domain == null ? "" : domain; this.description = description == null ? "" : description; this.version = version; @@ -66,6 +74,7 @@ public final class OnnxModelMetadata { * * @param producerName The model producer name. * @param graphName The model graph name. + * @param graphDescription The model graph name. * @param domain The model domain name. * @param description The model description. * @param version The model version. @@ -74,12 +83,14 @@ public final class OnnxModelMetadata { OnnxModelMetadata( String producerName, String graphName, + String graphDescription, String domain, String description, long version, Map customMetadata) { this.producerName = producerName == null ? "" : producerName; this.graphName = graphName == null ? "" : graphName; + this.graphDescription = graphDescription == null ? "" : graphDescription; this.domain = domain == null ? "" : domain; this.description = description == null ? "" : description; this.version = version; @@ -94,6 +105,7 @@ public final class OnnxModelMetadata { public OnnxModelMetadata(OnnxModelMetadata other) { this.producerName = other.producerName; this.graphName = other.graphName; + this.graphDescription = other.graphDescription; this.domain = other.domain; this.description = other.description; this.version = other.version; @@ -111,6 +123,7 @@ public boolean equals(Object o) { return version == that.version && producerName.equals(that.producerName) && graphName.equals(that.graphName) + && graphDescription.equals(that.graphDescription) && domain.equals(that.domain) && description.equals(that.description) && customMetadata.equals(that.customMetadata); @@ -118,7 +131,8 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(producerName, graphName, domain, description, version, customMetadata); + return Objects.hash( + producerName, graphName, graphDescription, domain, description, version, customMetadata); } /** @@ -139,6 +153,15 @@ public String getGraphName() { return graphName; } + /** + * Gets the graph description. + * + * @return The graph description. + */ + public String getGraphDescription() { + return graphDescription; + } + /** * Gets the domain. * @@ -195,6 +218,9 @@ public String toString() { + ", graphName='" + graphName + '\'' + + ", graphDescription='" + + graphDescription + + '\'' + ", domain='" + domain + '\'' diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 8e96227e0c2a7..de24f10f56272 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -385,7 +385,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata jclass metadataClazz = (*jniEnv)->FindClass(jniEnv, metadataClassName); //OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray) jmethodID metadataConstructor = (*jniEnv)->GetMethodID(jniEnv, metadataClazz, "", - "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V"); + "(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;J[Ljava/lang/String;)V"); // Get metadata OrtModelMetadata* metadata; @@ -402,6 +402,11 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata jstring graphStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + // Read out the graph description and convert it to a java.lang.String + checkOrtStatus(jniEnv,api,api->ModelMetadataGetGraphDescription(metadata, allocator, &charBuffer)); + jstring graphDescStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); + checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,charBuffer)); + // Read out the domain and convert it to a java.lang.String checkOrtStatus(jniEnv,api,api->ModelMetadataGetDomain(metadata, allocator, &charBuffer)); jstring domainStr = (*jniEnv)->NewStringUTF(jniEnv,charBuffer); @@ -449,8 +454,8 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtSession_constructMetadata } // Invoke the metadata constructor - //OnnxModelMetadata(String producerName, String graphName, String domain, String description, long version, String[] customMetadataArray) - jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, domainStr, descriptionStr, (jlong) version, customArray); + //OnnxModelMetadata(String producerName, String graphName, String graphDescription, String domain, String description, long version, String[] customMetadataArray) + jobject metadataJava = (*jniEnv)->NewObject(jniEnv, metadataClazz, metadataConstructor, producerStr, graphStr, graphDescStr, domainStr, descriptionStr, (jlong) version, customArray); // Release the metadata api->ReleaseModelMetadata(metadata); diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 2fdc0b506d6c5..b73583cab7441 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2021, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1168,6 +1168,36 @@ public void testLoadCustomLibrary() throws OrtException { } } + @Test + public void testModelMetadata() throws OrtException { + String modelPath = getResourcePath("/model_with_valid_ort_config_json.onnx").toString(); + + try (OrtEnvironment env = OrtEnvironment.getEnvironment("testModelMetadata")) { + try (OrtSession session = env.createSession(modelPath)) { + OnnxModelMetadata modelMetadata = session.getMetadata(); + + Assertions.assertEquals(1, modelMetadata.getVersion()); + + Assertions.assertEquals("Hari", modelMetadata.getProducerName()); + + Assertions.assertEquals("matmul test", modelMetadata.getGraphName()); + + Assertions.assertEquals("", modelMetadata.getDomain()); + + Assertions.assertEquals( + "This is a test model with a valid ORT config Json", modelMetadata.getDescription()); + + Assertions.assertEquals("graph description", modelMetadata.getGraphDescription()); + + Assertions.assertEquals(2, modelMetadata.getCustomMetadata().size()); + Assertions.assertEquals("dummy_value", modelMetadata.getCustomMetadata().get("dummy_key")); + Assertions.assertEquals( + "{\"session_options\": {\"inter_op_num_threads\": 5, \"intra_op_num_threads\": 2, \"graph_optimization_level\": 99, \"enable_profiling\": 1}}", + modelMetadata.getCustomMetadata().get("ort_config")); + } + } + } + @Test public void testModelInputBOOL() throws OrtException { // model takes 1x5 input of fixed type, echoes back