From 19d8bc26fe706abd21d942e86edcb7020a9b4924 Mon Sep 17 00:00:00 2001 From: James Yuzawa Date: Thu, 17 Nov 2022 22:29:42 -0500 Subject: [PATCH 1/2] performance improvements --- .../com/jyuzawa/onnxruntime/MapInfoImpl.java | 4 +- .../onnxruntime/ModelMetadataImpl.java | 4 +- .../com/jyuzawa/onnxruntime/NodeInfoImpl.java | 23 ++++- .../com/jyuzawa/onnxruntime/OnnxMapImpl.java | 24 ++--- .../jyuzawa/onnxruntime/OnnxMapLongImpl.java | 2 +- .../onnxruntime/OnnxMapStringImpl.java | 2 +- .../jyuzawa/onnxruntime/OnnxOptionalImpl.java | 12 +-- .../jyuzawa/onnxruntime/OnnxSequenceImpl.java | 4 +- .../onnxruntime/OnnxTensorBufferImpl.java | 12 +-- .../onnxruntime/OnnxTensorByteImpl.java | 2 +- .../onnxruntime/OnnxTensorDoubleImpl.java | 2 +- .../OnnxTensorElementDataType.java | 8 +- .../onnxruntime/OnnxTensorFloatImpl.java | 2 +- .../jyuzawa/onnxruntime/OnnxTensorImpl.java | 16 +-- .../onnxruntime/OnnxTensorIntImpl.java | 2 +- .../onnxruntime/OnnxTensorLongImpl.java | 2 +- .../onnxruntime/OnnxTensorShortImpl.java | 2 +- .../onnxruntime/OnnxTensorStringImpl.java | 11 +-- .../jyuzawa/onnxruntime/OnnxValueImpl.java | 34 +++---- .../onnxruntime/SessionBuilderImpl.java | 2 +- .../com/jyuzawa/onnxruntime/SessionImpl.java | 37 ++++--- .../jyuzawa/onnxruntime/TensorInfoImpl.java | 24 +++-- .../onnxruntime/TransactionBuilderImpl.java | 36 ++++--- .../jyuzawa/onnxruntime/TransactionImpl.java | 98 +++++++++---------- .../com/jyuzawa/onnxruntime/TypeInfoImpl.java | 36 +++---- .../com/jyuzawa/onnxruntime/SessionTest.java | 59 +++++++++++ 26 files changed, 268 insertions(+), 192 deletions(-) diff --git a/src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java index b17ed02..3131274 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java @@ -7,7 +7,7 @@ final class MapInfoImpl implements MapInfo { private final OnnxTensorElementDataType keyType; - private final TypeInfo typeInfo; + private final TypeInfoImpl typeInfo; MapInfoImpl(OnnxTensorElementDataType keyType, TypeInfoImpl typeInfo) { this.keyType = keyType; @@ -20,7 +20,7 @@ public OnnxTensorElementDataType getKeyType() { } @Override - public TypeInfo getValueType() { + public TypeInfoImpl getValueType() { return typeInfo; } diff --git a/src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java b/src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java index 50004a4..3e27fb6 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java @@ -24,8 +24,10 @@ final class ModelMetadataImpl implements ModelMetadata { private final long version; private final Map customMetadata; - ModelMetadataImpl(ApiImpl api, MemoryAddress metadata, MemoryAddress ortAllocator) { + ModelMetadataImpl(ApiImpl api, MemoryAddress session, MemoryAddress ortAllocator) { try (MemorySession allocator = MemorySession.openConfined()) { + MemoryAddress metadata = api.create(allocator, out -> api.SessionGetModelMetadata.apply(session, out)); + allocator.addCloseAction(() -> api.ReleaseModelMetadata.apply(metadata)); { MemoryAddress pointer = api.create( allocator, out -> api.ModelMetadataGetDescription.apply(metadata, ortAllocator, out)); diff --git a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java index 5a7475a..9731c28 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java @@ -4,13 +4,17 @@ */ package com.jyuzawa.onnxruntime; +import java.lang.foreign.MemorySegment; + final class NodeInfoImpl implements NodeInfo { private final String name; - private final TypeInfo typeInfo; + final MemorySegment nameSegment; + private final TypeInfoImpl typeInfo; - NodeInfoImpl(String name, TypeInfo typeInfo) { + NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) { this.name = name; + this.nameSegment = nameSegment; this.typeInfo = typeInfo; } @@ -20,7 +24,7 @@ public String getName() { } @Override - public TypeInfo getTypeInfo() { + public TypeInfoImpl getTypeInfo() { return typeInfo; } @@ -28,4 +32,17 @@ public TypeInfo getTypeInfo() { public String toString() { return "{NodeInfo: name=" + name + ", typeInfo=" + typeInfo + "}"; } + + @Override + public int hashCode() { + return name.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (o instanceof NodeInfoImpl) { + return ((NodeInfoImpl) o).name.equals(name); + } + return false; + } } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java index b3b5ba8..a277e3b 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java @@ -20,12 +20,12 @@ abstract class OnnxMapImpl extends OnnxValueImpl implements OnnxMap, OnnxTypedMap { - private final Function keyVectorFactory; + private final Function keyVectorFactory; private final Map data; private final Map unmodifiableData; - protected final MapInfo mapInfo; + protected final MapInfoImpl mapInfo; - protected OnnxMapImpl(MapInfo mapInfo, Function keyVectorFactory) { + protected OnnxMapImpl(MapInfoImpl mapInfo, Function keyVectorFactory) { super(OnnxType.MAP); this.keyVectorFactory = keyVectorFactory; this.data = new LinkedHashMap<>(); @@ -33,7 +33,7 @@ protected OnnxMapImpl(MapInfo mapInfo, Function keyVectorFactory) this.mapInfo = mapInfo; } - static final OnnxValueImpl fromTypeInfo(MapInfo mapInfo) { + static final OnnxValueImpl fromTypeInfo(MapInfoImpl mapInfo) { TypeInfo valueType = mapInfo.getValueType(); if (valueType.getType() != OnnxType.TENSOR || valueType.getTensorInfo().getElementCount() != 1) { throw new UnsupportedOperationException("OnnxMap only supports scalar values"); @@ -49,13 +49,13 @@ static final OnnxValueImpl fromTypeInfo(MapInfo mapInfo) { } } - private final T newKeyVector(int size) { - return keyVectorFactory.apply(new TensorInfoImpl(mapInfo.getKeyType(), size)); + private final T newKeyVector(int size, MemorySession scope) { + return keyVectorFactory.apply(TensorInfoImpl.of(mapInfo.getKeyType(), size, scope)); } - private final OnnxTensorImpl newValueVector(int size) { + private final OnnxTensorImpl newValueVector(int size, MemorySession scope) { return OnnxTensorImpl.fromTypeInfo( - new TensorInfoImpl(mapInfo.getValueType().getTensorInfo().getType(), size)); + TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, scope)); } protected abstract void implodeKeyVector(T keyVector, Set keys); @@ -95,10 +95,10 @@ public OnnxTypedMap asStringMap() { public MemoryAddress toNative( ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) { int size = data.size(); - T keyVector = newKeyVector(size); + T keyVector = newKeyVector(size, allocator); implodeKeyVector(keyVector, data.keySet()); OnnxTensorImpl valueVector = OnnxTensorImpl.fromTypeInfo( - new TensorInfoImpl(mapInfo.getValueType().getTensorInfo().getType(), size)); + TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, allocator)); valueVector.putScalars(data.values()); MemoryAddress keyAddress = keyVector.toNative(api, ortAllocator, memoryInfo, allocator); MemoryAddress valueAddress = valueVector.toNative(api, ortAllocator, memoryInfo, allocator); @@ -120,8 +120,8 @@ public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress ad MemoryAddress keyInfo = api.create(allocator, out -> api.GetTensorTypeAndShape.apply(keyAddress, out)); int size = Math.toIntExact(api.extractLong(allocator, out -> api.GetTensorShapeElementCount.apply(keyInfo, out))); - T keyVector = newKeyVector(size); - OnnxTensorImpl valueVector = newValueVector(size); + T keyVector = newKeyVector(size, allocator); + OnnxTensorImpl valueVector = newValueVector(size, allocator); keyVector.fromNative(api, ortAllocator, keyAddress, allocator); valueVector.fromNative(api, ortAllocator, valueAddress, allocator); valueVector.getScalars(explodeKeyVector(keyVector).map(this::set)); diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java index ea7bbf4..ce3bd50 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java @@ -11,7 +11,7 @@ final class OnnxMapLongImpl extends OnnxMapImpl { - OnnxMapLongImpl(MapInfo mapInfo) { + OnnxMapLongImpl(MapInfoImpl mapInfo) { super(mapInfo, OnnxTensorLongImpl::new); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapStringImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapStringImpl.java index 71534ed..776ad5e 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxMapStringImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxMapStringImpl.java @@ -10,7 +10,7 @@ final class OnnxMapStringImpl extends OnnxMapImpl { - OnnxMapStringImpl(MapInfo mapInfo) { + OnnxMapStringImpl(MapInfoImpl mapInfo) { super(mapInfo, OnnxTensorStringImpl::new); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java index 35c7a8b..878ff26 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java @@ -11,9 +11,9 @@ final class OnnxOptionalImpl extends OnnxValueImpl implements OnnxOptional { private OnnxValueImpl data; - private final TypeInfo typeInfo; + private final TypeInfoImpl typeInfo; - OnnxOptionalImpl(TypeInfo typeInfo) { + OnnxOptionalImpl(TypeInfoImpl typeInfo) { super(OnnxType.OPTIONAL); this.typeInfo = typeInfo; } @@ -23,10 +23,10 @@ public String toString() { return "{OnnxOptional: info=" + typeInfo + ", data=" + data + "}"; } - // @Override - // public OnnxOptional asOptional() { - // return this; - // } + // @Override + // public OnnxOptional asOptional() { + // return this; + // } @Override public TypeInfo getInfo() { diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java index c509acc..2bb3379 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java @@ -20,9 +20,9 @@ final class OnnxSequenceImpl extends OnnxValueImpl implements OnnxSequence { private final List data; private final List unmodifiableData; - private final TypeInfo typeInfo; + private final TypeInfoImpl typeInfo; - OnnxSequenceImpl(TypeInfo typeInfo) { + OnnxSequenceImpl(TypeInfoImpl typeInfo) { super(OnnxType.SEQUENCE); this.data = new ArrayList<>(); this.unmodifiableData = Collections.unmodifiableList(data); diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java index acb0411..c5f22e1 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java @@ -4,20 +4,17 @@ */ package com.jyuzawa.onnxruntime; -import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; - import java.lang.foreign.MemoryAddress; import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySession; import java.nio.Buffer; -import java.util.List; import java.util.function.IntFunction; abstract class OnnxTensorBufferImpl extends OnnxTensorImpl { protected final T buffer; - protected OnnxTensorBufferImpl(TensorInfo tensorInfo, IntFunction factory) { + protected OnnxTensorBufferImpl(TensorInfoImpl tensorInfo, IntFunction factory) { super(tensorInfo); this.buffer = factory.apply(Math.toIntExact(tensorInfo.getElementCount())); } @@ -35,17 +32,14 @@ public final MemoryAddress toNative( MemorySegment inputData = allocator.allocateArray(tensorInfo.getType().getValueLayout(), rawInputData.byteSize()); inputData.copyFrom(rawInputData); - List shape = tensorInfo.getShape(); - int shapeSize = shape.size(); - MemorySegment shapeData = allocator.allocateArray(C_LONG, shape(shape)); MemoryAddress tensor = api.create( allocator, out -> api.CreateTensorWithDataAsOrtValue.apply( memoryInfo, inputData.address(), inputData.byteSize(), - shapeData.address(), - shapeSize, + tensorInfo.shapeData.address(), + tensorInfo.getShape().size(), tensorInfo.getType().getNumber(), out)); allocator.addCloseAction(() -> { diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorByteImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorByteImpl.java index 75fd400..6f67d0e 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorByteImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorByteImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorByteImpl extends OnnxTensorBufferImpl { - OnnxTensorByteImpl(TensorInfo tensorInfo) { + OnnxTensorByteImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, ByteBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorDoubleImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorDoubleImpl.java index 628665a..2953949 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorDoubleImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorDoubleImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorDoubleImpl extends OnnxTensorBufferImpl { - OnnxTensorDoubleImpl(TensorInfo tensorInfo) { + OnnxTensorDoubleImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, DoubleBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorElementDataType.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorElementDataType.java index b0969ea..ae160c5 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorElementDataType.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorElementDataType.java @@ -11,8 +11,8 @@ import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_SHORT; +import java.lang.foreign.MemorySession; import java.lang.foreign.ValueLayout; -import java.util.Collections; /** * A tensor type from ONNX. @@ -44,7 +44,7 @@ public enum OnnxTensorElementDataType { private OnnxTensorElementDataType(int number, ValueLayout valueLayout) { this.number = number; this.valueLayout = valueLayout; - this.scalarInfo = new TensorInfoImpl(this, Collections.singletonList(1L), 1L); + this.scalarInfo = TensorInfoImpl.of(this, 1L, MemorySession.global()); } TensorInfo getScalarInfo() { @@ -57,7 +57,9 @@ public int getNumber() { /** * Get a level based off its internal number. - * @param number the internal number of the level + * + * @param number + * the internal number of the level * @return the level, UNDEFINED if not found */ public static final OnnxTensorElementDataType forNumber(int number) { diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorFloatImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorFloatImpl.java index 2581a0b..c612dc3 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorFloatImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorFloatImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorFloatImpl extends OnnxTensorBufferImpl { - OnnxTensorFloatImpl(TensorInfo tensorInfo) { + OnnxTensorFloatImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, FloatBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java index d58cc56..7d8e307 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java @@ -11,15 +11,14 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Collection; -import java.util.List; import java.util.NoSuchElementException; import java.util.stream.Stream; abstract class OnnxTensorImpl extends OnnxValueImpl implements OnnxTensor { - protected final TensorInfo tensorInfo; + protected final TensorInfoImpl tensorInfo; - protected OnnxTensorImpl(TensorInfo tensorInfo) { + protected OnnxTensorImpl(TensorInfoImpl tensorInfo) { super(OnnxType.TENSOR); this.tensorInfo = tensorInfo; } @@ -77,16 +76,7 @@ public String[] getStringBuffer() { abstract void getScalars(Stream scalars); - protected static final long[] shape(List original) { - int shapeSize = original.size(); - long[] shapeArray = new long[shapeSize]; - for (int i = 0; i < shapeSize; i++) { - shapeArray[i] = original.get(i); - } - return shapeArray; - } - - static final OnnxTensorImpl fromTypeInfo(TensorInfo tensorInfo) { + static final OnnxTensorImpl fromTypeInfo(TensorInfoImpl tensorInfo) { OnnxTensorElementDataType type = tensorInfo.getType(); switch (type) { case BOOL: diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorIntImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorIntImpl.java index 89e3899..6fc1a67 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorIntImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorIntImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorIntImpl extends OnnxTensorBufferImpl { - OnnxTensorIntImpl(TensorInfo tensorInfo) { + OnnxTensorIntImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, IntBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorLongImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorLongImpl.java index 6a3e5f6..5e31ba7 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorLongImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorLongImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorLongImpl extends OnnxTensorBufferImpl { - OnnxTensorLongImpl(TensorInfo tensorInfo) { + OnnxTensorLongImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, LongBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorShortImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorShortImpl.java index f14f114..cb3fae2 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorShortImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorShortImpl.java @@ -11,7 +11,7 @@ final class OnnxTensorShortImpl extends OnnxTensorBufferImpl { - OnnxTensorShortImpl(TensorInfo tensorInfo) { + OnnxTensorShortImpl(TensorInfoImpl tensorInfo) { super(tensorInfo, ShortBuffer::allocate); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java index 755da96..89ab665 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java @@ -5,7 +5,6 @@ package com.jyuzawa.onnxruntime; import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR; -import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; import java.lang.foreign.MemoryAddress; @@ -14,14 +13,13 @@ import java.util.Arrays; import java.util.Collection; import java.util.Iterator; -import java.util.List; import java.util.stream.Stream; final class OnnxTensorStringImpl extends OnnxTensorImpl { private final String[] buffer; - OnnxTensorStringImpl(TensorInfo tensorInfo) { + OnnxTensorStringImpl(TensorInfoImpl tensorInfo) { super(tensorInfo); this.buffer = new String[Math.toIntExact(tensorInfo.getElementCount())]; } @@ -45,15 +43,12 @@ public MemoryAddress toNative( for (int i = 0; i < numOutputs; i++) { stringArray.setAtIndex(C_POINTER, i, allocator.allocateUtf8String(buffer[i])); } - List shape = tensorInfo.getShape(); - int shapeSize = shape.size(); - MemorySegment shapeData = allocator.allocateArray(C_LONG, shape(shape)); MemoryAddress tensor = api.create( allocator, out -> api.CreateTensorAsOrtValue.apply( ortAllocator, - shapeData.address(), - shapeSize, + tensorInfo.shapeData.address(), + tensorInfo.getShape().size(), tensorInfo.getType().getNumber(), out)); api.checkStatus(api.FillStringTensor.apply(tensor, stringArray.address(), numOutputs)); diff --git a/src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java b/src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java index 09b3ca6..6a6818f 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java @@ -36,27 +36,27 @@ public OnnxMap asMap() { throw new NoSuchElementException("OnnxValue is not a map"); } - // @Override - // public OnnxOpaque asOpaque() { - // throw new NoSuchElementException("OnnxValue is not an opaque"); - // } + // @Override + // public OnnxOpaque asOpaque() { + // throw new NoSuchElementException("OnnxValue is not an opaque"); + // } // - // @Override - // public OnnxSparseTensor asSparseTensor() { - // throw new NoSuchElementException("OnnxValue is not a sparse tensor"); - // } + // @Override + // public OnnxSparseTensor asSparseTensor() { + // throw new NoSuchElementException("OnnxValue is not a sparse tensor"); + // } // - // @Override - // public OnnxOptional asOptional() { - // throw new NoSuchElementException("OnnxValue is not an optional"); - // } + // @Override + // public OnnxOptional asOptional() { + // throw new NoSuchElementException("OnnxValue is not an optional"); + // } abstract MemoryAddress toNative( ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession scope); abstract void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession scope); - static final OnnxValueImpl fromTypeInfo(TypeInfo typeInfo) { + static final OnnxValueImpl fromTypeInfo(TypeInfoImpl typeInfo) { OnnxType type = typeInfo.getType(); switch (type) { case TENSOR: @@ -65,10 +65,10 @@ static final OnnxValueImpl fromTypeInfo(TypeInfo typeInfo) { return new OnnxSequenceImpl(typeInfo.getSequenceInfo()); case MAP: return OnnxMapImpl.fromTypeInfo(typeInfo.getMapInfo()); - // case OPAQUE: - // return new OnnxOpaqueImpl(typeInfo.getOpaqueInfo()); - // case OPTIONAL: - // return new OnnxOptionalImpl(typeInfo.getOptionalInfo()); + // case OPAQUE: + // return new OnnxOpaqueImpl(typeInfo.getOpaqueInfo()); + // case OPTIONAL: + // return new OnnxOptionalImpl(typeInfo.getOptionalInfo()); default: throw new UnsupportedOperationException("OnnxValue with type " + type + " is not supported"); } diff --git a/src/main/java/com/jyuzawa/onnxruntime/SessionBuilderImpl.java b/src/main/java/com/jyuzawa/onnxruntime/SessionBuilderImpl.java index c18b8e9..6334383 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/SessionBuilderImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/SessionBuilderImpl.java @@ -212,7 +212,7 @@ private MemoryAddress newSessionOptions(MemorySession allocator) { @Override public Session build() throws IOException { - try (MemorySession allocator = MemorySession.openShared()) { + try (MemorySession allocator = MemorySession.openConfined()) { final MemorySegment mappedBuf; if (buffer != null) { diff --git a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java index 6279def..e9bea36 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java @@ -5,6 +5,8 @@ package com.jyuzawa.onnxruntime; import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; +import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtArenaAllocator; +import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtMemTypeDefault; import java.lang.foreign.Addressable; import java.lang.foreign.MemoryAddress; @@ -14,10 +16,11 @@ final class SessionImpl extends ManagedImpl implements Session { - private final NamedCollection overridableInitializers; - private final NamedCollection inputs; - private final NamedCollection outputs; + private final NamedCollection overridableInitializers; + private final NamedCollection inputs; + private final NamedCollection outputs; private final ModelMetadata modelMetadata; + private final MemoryAddress memoryInfo; private final MemoryAddress ortAllocator; public SessionImpl(ApiImpl api, MemorySession allocator, MemoryAddress session) { @@ -47,9 +50,12 @@ public SessionImpl(ApiImpl api, MemorySession allocator, MemoryAddress session) api.SessionGetOutputCount::apply, api.SessionGetOutputName::apply, api.SessionGetOutputTypeInfo::apply); - MemoryAddress metadata = api.create(allocator, out -> api.SessionGetModelMetadata.apply(session, out)); - this.modelMetadata = new ModelMetadataImpl(api, metadata, ortAllocator); - api.ReleaseModelMetadata.apply(metadata); + this.modelMetadata = new ModelMetadataImpl(api, session, ortAllocator); + this.memoryInfo = api.create( + allocator, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out)); + allocator.addCloseAction(() -> { + api.ReleaseMemoryInfo.apply(memoryInfo); + }); } private interface GetCount { @@ -64,7 +70,7 @@ private interface GetTypeInfo { Addressable apply(MemoryAddress session, long idx, MemoryAddress out); } - private static NamedCollection createMap( + private static NamedCollection createMap( ApiImpl api, MemorySession allocator, MemoryAddress ortAllocator, @@ -75,27 +81,29 @@ private static NamedCollection createMap( MemorySegment numInputsSegment = allocator.allocate(C_LONG); api.checkStatus(getCount.apply(session, numInputsSegment.address())); long numInputs = numInputsSegment.getAtIndex(C_LONG, 0); - LinkedHashMap inputs = new LinkedHashMap<>(); + LinkedHashMap inputs = new LinkedHashMap<>(); for (long i = 0; i < numInputs; i++) { final long j = i; MemoryAddress nameSegment = api.create(allocator, out -> getName.apply(session, j, ortAllocator, out)); String name = nameSegment.getUtf8String(0); api.checkStatus(api.AllocatorFree.apply(ortAllocator, nameSegment)); MemoryAddress typeInfoAddress = api.create(allocator, out -> getTypeInfo.apply(session, j, out)); - TypeInfoImpl typeInfo = new TypeInfoImpl(api, typeInfoAddress, ortAllocator); - inputs.put(name, new NodeInfoImpl(name, typeInfo)); + TypeInfoImpl typeInfo = new TypeInfoImpl(api, typeInfoAddress, allocator, ortAllocator); + inputs.put(name, new NodeInfoImpl(name, allocator.allocateUtf8String(name), typeInfo)); } return new NamedCollectionImpl<>(inputs); } + @SuppressWarnings("unchecked") @Override public NamedCollection getOverridableInitializers() { - return overridableInitializers; + return (NamedCollection) (NamedCollection) overridableInitializers; } + @SuppressWarnings("unchecked") @Override public NamedCollection getInputs() { - return inputs; + return (NamedCollection) (NamedCollection) inputs; } @Override @@ -103,13 +111,14 @@ public ModelMetadata getModelMetadata() { return modelMetadata; } + @SuppressWarnings("unchecked") @Override public NamedCollection getOutputs() { - return outputs; + return (NamedCollection) (NamedCollection) outputs; } @Override public Transaction.Builder newTransaction() { - return new TransactionBuilderImpl(api, address, ortAllocator, inputs, outputs); + return new TransactionBuilderImpl(api, address, memoryInfo, ortAllocator, inputs, outputs); } } diff --git a/src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java index d69c775..34ee775 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java @@ -4,6 +4,11 @@ */ package com.jyuzawa.onnxruntime; +import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG; + +import java.lang.foreign.MemorySegment; +import java.lang.foreign.MemorySession; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -11,21 +16,24 @@ final class TensorInfoImpl implements TensorInfo { // TODO: symbolic dims private final OnnxTensorElementDataType type; + final MemorySegment shapeData; private final List shape; private final long elementCount; - TensorInfoImpl(OnnxTensorElementDataType type, List shape, long elementCount) { + TensorInfoImpl(OnnxTensorElementDataType type, MemorySegment shapeData, int dimCount, long elementCount) { this.type = type; - this.shape = shape; + List shape = new ArrayList<>(dimCount); + for (int i = 0; i < dimCount; i++) { + shape.add(shapeData.getAtIndex(C_LONG, i)); + } + this.shape = Collections.unmodifiableList(shape); + this.shapeData = shapeData; this.elementCount = elementCount; } - TensorInfoImpl(OnnxTensorElementDataType type, long elementCount) { - this(type, Collections.singletonList(elementCount), elementCount); - } - - TensorInfoImpl(OnnxTensorElementDataType type) { - this(type, 1); + static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, MemorySession scope) { + MemorySegment shapeData = scope.allocateArray(C_LONG, new long[] {elementCount}); + return new TensorInfoImpl(type, shapeData, 1, elementCount); } @Override diff --git a/src/main/java/com/jyuzawa/onnxruntime/TransactionBuilderImpl.java b/src/main/java/com/jyuzawa/onnxruntime/TransactionBuilderImpl.java index 3826fc1..8a262f6 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/TransactionBuilderImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/TransactionBuilderImpl.java @@ -4,11 +4,13 @@ */ package com.jyuzawa.onnxruntime; +import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; + import com.jyuzawa.onnxruntime.Transaction.Builder; import java.lang.foreign.MemoryAddress; +import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySession; import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -16,11 +18,12 @@ final class TransactionBuilderImpl implements Transaction.Builder { final ApiImpl api; final MemoryAddress session; + final MemoryAddress memoryInfo; final MemoryAddress ortAllocator; - final Map inputs; - final List outputs; - private final NamedCollection allInputs; - private final NamedCollection allOutputs; + final List inputs; + final List outputs; + private final NamedCollection allInputs; + private final NamedCollection allOutputs; private OnnxRuntimeLoggingLevel logSeverityLevel; private Integer logVerbosityLevel; private String runTag; @@ -29,16 +32,18 @@ final class TransactionBuilderImpl implements Transaction.Builder { public TransactionBuilderImpl( ApiImpl api, MemoryAddress session, + MemoryAddress memoryInfo, MemoryAddress ortAllocator, - NamedCollection allInputs, - NamedCollection allOutputs) { + NamedCollection allInputs, + NamedCollection allOutputs) { this.api = api; this.session = session; + this.memoryInfo = memoryInfo; this.ortAllocator = ortAllocator; this.allInputs = allInputs; this.allOutputs = allOutputs; - this.inputs = new LinkedHashMap<>(); - this.outputs = new ArrayList<>(); + this.inputs = new ArrayList<>(1); + this.outputs = new ArrayList<>(1); } @Override @@ -52,9 +57,9 @@ public Transaction build() { return new TransactionImpl(this); } - private OnnxValue addInput(NodeInfo node) { + private OnnxValue addInput(NodeInfoImpl node) { OnnxValueImpl input = OnnxValueImpl.fromTypeInfo(node.getTypeInfo()); - inputs.put(node.getName(), input); + inputs.add(new InputTuple(node, input)); return input; } @@ -68,7 +73,7 @@ public OnnxValue addInput(int index) { return addInput(allInputs.get(index)); } - private Builder addOutput(NodeInfo node) { + private Builder addOutput(NodeInfoImpl node) { outputs.add(node); return this; } @@ -107,8 +112,9 @@ public Builder setRunConfigMap(Map config) { return this; } - MemoryAddress newRunOptions(MemorySession scope) { - MemoryAddress runOptions = api.create(scope, out -> api.CreateRunOptions.apply(out)); + MemoryAddress newRunOptions(MemorySession scope, MemorySegment memorySegment) { + api.checkStatus(api.CreateRunOptions.apply(memorySegment.address())); + MemoryAddress runOptions = memorySegment.getAtIndex(C_POINTER, 0); if (logSeverityLevel != null) { api.checkStatus(api.RunOptionsSetRunLogSeverityLevel.apply(runOptions, logSeverityLevel.getNumber())); } @@ -129,4 +135,6 @@ MemoryAddress newRunOptions(MemorySession scope) { } return runOptions; } + + record InputTuple(NodeInfoImpl nodeInfo, OnnxValueImpl value) {} } diff --git a/src/main/java/com/jyuzawa/onnxruntime/TransactionImpl.java b/src/main/java/com/jyuzawa/onnxruntime/TransactionImpl.java index e9924db..dd1dc55 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/TransactionImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/TransactionImpl.java @@ -5,76 +5,74 @@ package com.jyuzawa.onnxruntime; import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER; -import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtArenaAllocator; -import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.OrtMemTypeDefault; +import com.jyuzawa.onnxruntime.TransactionBuilderImpl.InputTuple; import java.lang.foreign.MemoryAddress; import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySession; import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; final class TransactionImpl implements Transaction { - - // private final Object cancelLock; - // private MemoryAddress runOptions; + // private final Object cancelLock; + // private MemoryAddress runOptions; private final TransactionBuilderImpl builder; TransactionImpl(TransactionBuilderImpl builder) { this.builder = builder; - // this.cancelLock = new Object(); + // this.cancelLock = new Object(); } - // @Override - // public void cancel() { - // synchronized (cancelLock) { - // if (runOptions != null) { - // ApiImpl api = builder.api; - // api.checkStatus(api.RunOptionsSetTerminate.apply(runOptions)); - // } - // } - // } + // @Override + // public void cancel() { + // synchronized (cancelLock) { + // if (runOptions != null) { + // ApiImpl api = builder.api; + // api.checkStatus(api.RunOptionsSetTerminate.apply(runOptions)); + // } + // } + // } @Override public NamedCollection run() { ApiImpl api = builder.api; MemoryAddress ortAllocator = builder.ortAllocator; - try (MemorySession allocator = MemorySession.openShared()) { - MemoryAddress memoryInfo = api.create( - allocator, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out)); - allocator.addCloseAction(() -> { - api.ReleaseMemoryInfo.apply(memoryInfo); - }); - - Map inputs = builder.inputs; + try (MemorySession allocator = MemorySession.openConfined()) { + List inputs = builder.inputs; int numInputs = inputs.size(); - MemorySegment inputNames = allocator.allocateArray(C_POINTER, numInputs); - MemorySegment inputValues = allocator.allocateArray(C_POINTER, numInputs); - int idx = 0; - for (Map.Entry entry : inputs.entrySet()) { - MemorySegment input1 = allocator.allocateUtf8String(entry.getKey()); - inputNames.setAtIndex(C_POINTER, idx, input1); - MemoryAddress valueAddress = entry.getValue().toNative(api, ortAllocator, memoryInfo, allocator); - inputValues.setAtIndex(C_POINTER, idx, valueAddress); - idx++; + List outputs = builder.outputs; + int numOutputs = outputs.size(); + long sizeOfPointer = C_POINTER.byteSize(); + MemorySegment segment = allocator.allocateArray(C_POINTER, numInputs * 2 + numOutputs * 2 + 1); + long inputsBytes = numInputs * sizeOfPointer; + long offset = 0; + MemorySegment inputNames = segment.asSlice(offset, inputsBytes); + offset += inputsBytes; + MemorySegment inputValues = segment.asSlice(offset, inputsBytes); + offset += inputsBytes; + long outputsBytes = numOutputs * sizeOfPointer; + MemorySegment outputNames = segment.asSlice(offset, outputsBytes); + offset += outputsBytes; + MemorySegment outputValues = segment.asSlice(offset, outputsBytes); + offset += outputsBytes; + MemorySegment runOptionsSegment = segment.asSlice(offset, sizeOfPointer); + for (int i = 0; i < numOutputs; i++) { + InputTuple inputTuple = inputs.get(i); + inputNames.setAtIndex(C_POINTER, i, inputTuple.nodeInfo().nameSegment); + MemoryAddress valueAddress = + inputTuple.value().toNative(api, ortAllocator, builder.memoryInfo, allocator); + inputValues.setAtIndex(C_POINTER, i, valueAddress); } - List outputs = builder.outputs; - int numOutputs = outputs.size(); - MemorySegment outputNames = allocator.allocateArray(C_POINTER, numOutputs); for (int i = 0; i < numOutputs; i++) { - MemorySegment input1 = - allocator.allocateUtf8String(outputs.get(i).getName()); - outputNames.setAtIndex(C_POINTER, i, input1); + outputNames.setAtIndex(C_POINTER, i, outputs.get(i).nameSegment); } - MemorySegment output = allocator.allocate(C_POINTER); - MemoryAddress runOptionsAddress = builder.newRunOptions(allocator); - // synchronized (cancelLock) { - // this.runOptions = runOptionsAddress; - // } + MemoryAddress runOptionsAddress = builder.newRunOptions(allocator, runOptionsSegment); + // synchronized (cancelLock) { + // this.runOptions = runOptionsAddress; + // } try { api.checkStatus(api.Run.apply( builder.session, @@ -84,18 +82,18 @@ public NamedCollection run() { numInputs, outputNames.address(), numOutputs, - output.address())); + outputValues.address())); } finally { - // synchronized (cancelLock) { + // synchronized (cancelLock) { api.ReleaseRunOptions.apply(runOptionsAddress); - // runOptions = null; - // } + // runOptions = null; + // } } LinkedHashMap out = new LinkedHashMap<>(outputs.size()); for (int i = 0; i < outputs.size(); i++) { - MemoryAddress outputAddress = output.getAtIndex(C_POINTER, i); + MemoryAddress outputAddress = outputValues.getAtIndex(C_POINTER, i); // TODO: get typeinfo from result - NodeInfo nodeInfo = outputs.get(i); + NodeInfoImpl nodeInfo = outputs.get(i); OnnxValueImpl outputValue = OnnxValueImpl.fromTypeInfo(nodeInfo.getTypeInfo()); outputValue.fromNative(api, ortAllocator, outputAddress, allocator); out.put(nodeInfo.getName(), outputValue); diff --git a/src/main/java/com/jyuzawa/onnxruntime/TypeInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/TypeInfoImpl.java index 6853023..eeca030 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/TypeInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/TypeInfoImpl.java @@ -9,29 +9,26 @@ import java.lang.foreign.MemoryAddress; import java.lang.foreign.MemorySegment; import java.lang.foreign.MemorySession; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.NoSuchElementException; final class TypeInfoImpl implements TypeInfo { // TODO: denotation private final OnnxType type; - private final TensorInfo tensorInfo; - private final MapInfo mapInfo; - private final TypeInfo sequenceInfo; + private final TensorInfoImpl tensorInfo; + private final MapInfoImpl mapInfo; + private final TypeInfoImpl sequenceInfo; - TypeInfoImpl(ApiImpl api, MemoryAddress typeInfo, MemoryAddress ortAllocator) { + TypeInfoImpl(ApiImpl api, MemoryAddress typeInfo, MemorySession sessionAllocator, MemoryAddress ortAllocator) { try (MemorySession allocator = MemorySession.openConfined()) { allocator.addCloseAction(() -> { api.ReleaseTypeInfo.apply(typeInfo); }); this.type = OnnxType.forNumber( api.extractInt(allocator, out -> api.GetOnnxTypeFromTypeInfo.apply(typeInfo, out))); - TensorInfo tensorInfo = null; - MapInfo mapInfo = null; - TypeInfo sequenceInfo = null; + TensorInfoImpl tensorInfo = null; + MapInfoImpl mapInfo = null; + TypeInfoImpl sequenceInfo = null; if (type == OnnxType.TENSOR || type == OnnxType.SPARSETENSOR) { MemoryAddress ortTensorInfo = @@ -39,15 +36,11 @@ final class TypeInfoImpl implements TypeInfo { OnnxTensorElementDataType dataType = OnnxTensorElementDataType.forNumber( api.extractInt(allocator, out -> api.GetTensorElementType.apply(ortTensorInfo, out))); int dimCount = api.extractInt(allocator, out -> api.GetDimensionsCount.apply(ortTensorInfo, out)); - MemorySegment dims = allocator.allocateArray(C_LONG, dimCount); + MemorySegment dims = sessionAllocator.allocateArray(C_LONG, dimCount); api.checkStatus(api.GetDimensions.apply(ortTensorInfo, dims.address(), dimCount)); long elementCount = api.extractInt(allocator, out -> api.GetTensorShapeElementCount.apply(ortTensorInfo, out)); - List shape = new ArrayList<>(dimCount); - for (int i = 0; i < dimCount; i++) { - shape.add(dims.getAtIndex(C_LONG, i)); - } - tensorInfo = new TensorInfoImpl(dataType, Collections.unmodifiableList(shape), elementCount); + tensorInfo = new TensorInfoImpl(dataType, dims, dimCount, elementCount); } else if (type == OnnxType.MAP) { MemoryAddress ortMapInfo = api.create(allocator, out -> api.CastTypeInfoToMapTypeInfo.apply(typeInfo, out)); @@ -55,13 +48,14 @@ final class TypeInfoImpl implements TypeInfo { api.extractInt(allocator, out -> api.GetMapKeyType.apply(ortMapInfo, out))); MemoryAddress valueTypeAddress = api.create(allocator, out -> api.GetMapValueType.apply(ortMapInfo, out)); - mapInfo = new MapInfoImpl(keyType, new TypeInfoImpl(api, valueTypeAddress, ortAllocator)); + mapInfo = new MapInfoImpl( + keyType, new TypeInfoImpl(api, valueTypeAddress, sessionAllocator, ortAllocator)); } else if (type == OnnxType.SEQUENCE) { MemoryAddress ortSequenceInfo = api.create(allocator, out -> api.CastTypeInfoToSequenceTypeInfo.apply(typeInfo, out)); MemoryAddress valueTypeAddress = api.create(allocator, out -> api.GetSequenceElementType.apply(ortSequenceInfo, out)); - sequenceInfo = new TypeInfoImpl(api, valueTypeAddress, ortAllocator); + sequenceInfo = new TypeInfoImpl(api, valueTypeAddress, sessionAllocator, ortAllocator); } else { throw new UnsupportedOperationException("unsupported type: " + type); } @@ -77,7 +71,7 @@ public OnnxType getType() { } @Override - public TensorInfo getTensorInfo() { + public TensorInfoImpl getTensorInfo() { if (tensorInfo == null) { throw new NoSuchElementException("tensor"); } @@ -85,7 +79,7 @@ public TensorInfo getTensorInfo() { } @Override - public MapInfo getMapInfo() { + public MapInfoImpl getMapInfo() { if (mapInfo == null) { throw new NoSuchElementException("map"); } @@ -105,7 +99,7 @@ public String toString() { } @Override - public TypeInfo getSequenceInfo() { + public TypeInfoImpl getSequenceInfo() { if (sequenceInfo == null) { throw new NoSuchElementException("sequence"); } diff --git a/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java b/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java index b9a7a4f..1725da4 100644 --- a/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java +++ b/src/test/java/com/jyuzawa/onnxruntime/SessionTest.java @@ -162,6 +162,65 @@ public void infoTest() throws IOException { } } + @Test + public void multiTest() throws IOException { + String inputName1 = "in1"; + String inputName2 = "in2"; + String outputName1 = "out1"; + String outputName2 = "out2"; + TypeProto type = TypeProto.newBuilder() + .setTensorType(Tensor.newBuilder() + .setElemType(DataType.FLOAT_VALUE) + .setShape(TensorShapeProto.newBuilder() + .addDim(Dimension.newBuilder().setDimValue(1)) + .addDim(Dimension.newBuilder().setDimValue(3)))) + .build(); + ByteBuffer byteBuffer = ModelProto.newBuilder() + .setIrVersion(8) + .addOpsetImport(OperatorSetIdProto.newBuilder().setVersion(10)) + .setGraph(GraphProto.newBuilder() + .addNode(NodeProto.newBuilder() + .addInput(inputName1) + .addOutput(outputName1) + .setOpType("Identity")) + .addInput( + ValueInfoProto.newBuilder().setName(inputName1).setType(type)) + .addOutput( + ValueInfoProto.newBuilder().setName(outputName1).setType(type)) + .addNode(NodeProto.newBuilder() + .addInput(inputName2) + .addOutput(outputName2) + .setOpType("Identity")) + .addInput( + ValueInfoProto.newBuilder().setName(inputName2).setType(type)) + .addOutput( + ValueInfoProto.newBuilder().setName(outputName2).setType(type))) + .build() + .toByteString() + .asReadOnlyByteBuffer(); + try (Session session = + environment.newSession().setByteBuffer(byteBuffer).build()) { + Transaction.Builder txn = session.newTransaction(); + float[] rawInput1 = new float[] {554354, 52345234, 143646}; + txn.addInput(0).asTensor().getFloatBuffer().put(rawInput1); + float[] rawInput2 = new float[] {5346, 62346, 2345}; + txn.addInput(1).asTensor().getFloatBuffer().put(rawInput2); + txn.addOutput(0); + txn.addOutput(1); + NamedCollection output = txn.build().run(); + float[] rawOutput1 = new float[3]; + OnnxValue outputValue1 = output.get(0); + OnnxTensor outputTensor1 = outputValue1.asTensor(); + outputTensor1.getFloatBuffer().get(rawOutput1); + assertTrue(Arrays.equals(rawInput1, rawOutput1)); + float[] rawOutput2 = new float[3]; + OnnxValue outputValue2 = output.get(1); + OnnxTensor outputTensor2 = outputValue2.asTensor(); + outputTensor2.getFloatBuffer().get(rawOutput2); + assertTrue(Arrays.equals(rawInput2, rawOutput2)); + } + } + @Test public void floatTest() throws IOException { TypeProto type = TypeProto.newBuilder() From 8d368a173811082f953321d420a57e83364ea6a5 Mon Sep 17 00:00:00 2001 From: James Yuzawa Date: Thu, 17 Nov 2022 22:48:42 -0500 Subject: [PATCH 2/2] Update NodeInfoImpl.java --- .../java/com/jyuzawa/onnxruntime/NodeInfoImpl.java | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java index 9731c28..7de583e 100644 --- a/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java +++ b/src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java @@ -32,17 +32,4 @@ public TypeInfoImpl getTypeInfo() { public String toString() { return "{NodeInfo: name=" + name + ", typeInfo=" + typeInfo + "}"; } - - @Override - public int hashCode() { - return name.hashCode(); - } - - @Override - public boolean equals(Object o) { - if (o instanceof NodeInfoImpl) { - return ((NodeInfoImpl) o).name.equals(name); - } - return false; - } }