Skip to content

Commit

Permalink
Merge pull request #66 from yuzawa-san/perf
Browse files Browse the repository at this point in the history
performance improvements
  • Loading branch information
yuzawa-san authored Nov 18, 2022
2 parents 12fd74f + 8d368a1 commit 40e14ef
Show file tree
Hide file tree
Showing 26 changed files with 255 additions and 192 deletions.
4 changes: 2 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/MapInfoImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
final class MapInfoImpl implements MapInfo {

private final OnnxTensorElementDataType keyType;
private final TypeInfo typeInfo;
private final TypeInfoImpl typeInfo;

MapInfoImpl(OnnxTensorElementDataType keyType, TypeInfoImpl typeInfo) {
this.keyType = keyType;
Expand All @@ -20,7 +20,7 @@ public OnnxTensorElementDataType getKeyType() {
}

@Override
public TypeInfo getValueType() {
public TypeInfoImpl getValueType() {
return typeInfo;
}

Expand Down
4 changes: 3 additions & 1 deletion src/main/java/com/jyuzawa/onnxruntime/ModelMetadataImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ final class ModelMetadataImpl implements ModelMetadata {
private final long version;
private final Map<String, String> customMetadata;

ModelMetadataImpl(ApiImpl api, MemoryAddress metadata, MemoryAddress ortAllocator) {
ModelMetadataImpl(ApiImpl api, MemoryAddress session, MemoryAddress ortAllocator) {
try (MemorySession allocator = MemorySession.openConfined()) {
MemoryAddress metadata = api.create(allocator, out -> api.SessionGetModelMetadata.apply(session, out));
allocator.addCloseAction(() -> api.ReleaseModelMetadata.apply(metadata));
{
MemoryAddress pointer = api.create(
allocator, out -> api.ModelMetadataGetDescription.apply(metadata, ortAllocator, out));
Expand Down
10 changes: 7 additions & 3 deletions src/main/java/com/jyuzawa/onnxruntime/NodeInfoImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
*/
package com.jyuzawa.onnxruntime;

import java.lang.foreign.MemorySegment;

final class NodeInfoImpl implements NodeInfo {

private final String name;
private final TypeInfo typeInfo;
final MemorySegment nameSegment;
private final TypeInfoImpl typeInfo;

NodeInfoImpl(String name, TypeInfo typeInfo) {
NodeInfoImpl(String name, MemorySegment nameSegment, TypeInfoImpl typeInfo) {
this.name = name;
this.nameSegment = nameSegment;
this.typeInfo = typeInfo;
}

Expand All @@ -20,7 +24,7 @@ public String getName() {
}

@Override
public TypeInfo getTypeInfo() {
public TypeInfoImpl getTypeInfo() {
return typeInfo;
}

Expand Down
24 changes: 12 additions & 12 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,20 @@

abstract class OnnxMapImpl<K, T extends OnnxTensorImpl> extends OnnxValueImpl implements OnnxMap, OnnxTypedMap<K> {

private final Function<TensorInfo, T> keyVectorFactory;
private final Function<TensorInfoImpl, T> keyVectorFactory;
private final Map<K, OnnxTensorImpl> data;
private final Map<K, OnnxValue> unmodifiableData;
protected final MapInfo mapInfo;
protected final MapInfoImpl mapInfo;

protected OnnxMapImpl(MapInfo mapInfo, Function<TensorInfo, T> keyVectorFactory) {
protected OnnxMapImpl(MapInfoImpl mapInfo, Function<TensorInfoImpl, T> keyVectorFactory) {
super(OnnxType.MAP);
this.keyVectorFactory = keyVectorFactory;
this.data = new LinkedHashMap<>();
this.unmodifiableData = Collections.unmodifiableMap(data);
this.mapInfo = mapInfo;
}

static final OnnxValueImpl fromTypeInfo(MapInfo mapInfo) {
static final OnnxValueImpl fromTypeInfo(MapInfoImpl mapInfo) {
TypeInfo valueType = mapInfo.getValueType();
if (valueType.getType() != OnnxType.TENSOR || valueType.getTensorInfo().getElementCount() != 1) {
throw new UnsupportedOperationException("OnnxMap only supports scalar values");
Expand All @@ -49,13 +49,13 @@ static final OnnxValueImpl fromTypeInfo(MapInfo mapInfo) {
}
}

private final T newKeyVector(int size) {
return keyVectorFactory.apply(new TensorInfoImpl(mapInfo.getKeyType(), size));
private final T newKeyVector(int size, MemorySession scope) {
return keyVectorFactory.apply(TensorInfoImpl.of(mapInfo.getKeyType(), size, scope));
}

private final OnnxTensorImpl newValueVector(int size) {
private final OnnxTensorImpl newValueVector(int size, MemorySession scope) {
return OnnxTensorImpl.fromTypeInfo(
new TensorInfoImpl(mapInfo.getValueType().getTensorInfo().getType(), size));
TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, scope));
}

protected abstract void implodeKeyVector(T keyVector, Set<K> keys);
Expand Down Expand Up @@ -95,10 +95,10 @@ public OnnxTypedMap<String> asStringMap() {
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
int size = data.size();
T keyVector = newKeyVector(size);
T keyVector = newKeyVector(size, allocator);
implodeKeyVector(keyVector, data.keySet());
OnnxTensorImpl valueVector = OnnxTensorImpl.fromTypeInfo(
new TensorInfoImpl(mapInfo.getValueType().getTensorInfo().getType(), size));
TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, allocator));
valueVector.putScalars(data.values());
MemoryAddress keyAddress = keyVector.toNative(api, ortAllocator, memoryInfo, allocator);
MemoryAddress valueAddress = valueVector.toNative(api, ortAllocator, memoryInfo, allocator);
Expand All @@ -120,8 +120,8 @@ public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress ad
MemoryAddress keyInfo = api.create(allocator, out -> api.GetTensorTypeAndShape.apply(keyAddress, out));
int size =
Math.toIntExact(api.extractLong(allocator, out -> api.GetTensorShapeElementCount.apply(keyInfo, out)));
T keyVector = newKeyVector(size);
OnnxTensorImpl valueVector = newValueVector(size);
T keyVector = newKeyVector(size, allocator);
OnnxTensorImpl valueVector = newValueVector(size, allocator);
keyVector.fromNative(api, ortAllocator, keyAddress, allocator);
valueVector.fromNative(api, ortAllocator, valueAddress, allocator);
valueVector.getScalars(explodeKeyVector(keyVector).map(this::set));
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/com/jyuzawa/onnxruntime/OnnxMapLongImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxMapLongImpl extends OnnxMapImpl<Long, OnnxTensorLongImpl> {

OnnxMapLongImpl(MapInfo mapInfo) {
OnnxMapLongImpl(MapInfoImpl mapInfo) {
super(mapInfo, OnnxTensorLongImpl::new);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

final class OnnxMapStringImpl extends OnnxMapImpl<String, OnnxTensorStringImpl> {

OnnxMapStringImpl(MapInfo mapInfo) {
OnnxMapStringImpl(MapInfoImpl mapInfo) {
super(mapInfo, OnnxTensorStringImpl::new);
}

Expand Down
12 changes: 6 additions & 6 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
final class OnnxOptionalImpl extends OnnxValueImpl implements OnnxOptional {

private OnnxValueImpl data;
private final TypeInfo typeInfo;
private final TypeInfoImpl typeInfo;

OnnxOptionalImpl(TypeInfo typeInfo) {
OnnxOptionalImpl(TypeInfoImpl typeInfo) {
super(OnnxType.OPTIONAL);
this.typeInfo = typeInfo;
}
Expand All @@ -23,10 +23,10 @@ public String toString() {
return "{OnnxOptional: info=" + typeInfo + ", data=" + data + "}";
}

// @Override
// public OnnxOptional asOptional() {
// return this;
// }
// @Override
// public OnnxOptional asOptional() {
// return this;
// }

@Override
public TypeInfo getInfo() {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ final class OnnxSequenceImpl extends OnnxValueImpl implements OnnxSequence {

private final List<OnnxValueImpl> data;
private final List<OnnxValue> unmodifiableData;
private final TypeInfo typeInfo;
private final TypeInfoImpl typeInfo;

OnnxSequenceImpl(TypeInfo typeInfo) {
OnnxSequenceImpl(TypeInfoImpl typeInfo) {
super(OnnxType.SEQUENCE);
this.data = new ArrayList<>();
this.unmodifiableData = Collections.unmodifiableList(data);
Expand Down
12 changes: 3 additions & 9 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@
*/
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.nio.Buffer;
import java.util.List;
import java.util.function.IntFunction;

abstract class OnnxTensorBufferImpl<T extends Buffer> extends OnnxTensorImpl {

protected final T buffer;

protected OnnxTensorBufferImpl(TensorInfo tensorInfo, IntFunction<T> factory) {
protected OnnxTensorBufferImpl(TensorInfoImpl tensorInfo, IntFunction<T> factory) {
super(tensorInfo);
this.buffer = factory.apply(Math.toIntExact(tensorInfo.getElementCount()));
}
Expand All @@ -35,17 +32,14 @@ public final MemoryAddress toNative(
MemorySegment inputData =
allocator.allocateArray(tensorInfo.getType().getValueLayout(), rawInputData.byteSize());
inputData.copyFrom(rawInputData);
List<Long> shape = tensorInfo.getShape();
int shapeSize = shape.size();
MemorySegment shapeData = allocator.allocateArray(C_LONG, shape(shape));
MemoryAddress tensor = api.create(
allocator,
out -> api.CreateTensorWithDataAsOrtValue.apply(
memoryInfo,
inputData.address(),
inputData.byteSize(),
shapeData.address(),
shapeSize,
tensorInfo.shapeData.address(),
tensorInfo.getShape().size(),
tensorInfo.getType().getNumber(),
out));
allocator.addCloseAction(() -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorByteImpl extends OnnxTensorBufferImpl<ByteBuffer> {

OnnxTensorByteImpl(TensorInfo tensorInfo) {
OnnxTensorByteImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, ByteBuffer::allocate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorDoubleImpl extends OnnxTensorBufferImpl<DoubleBuffer> {

OnnxTensorDoubleImpl(TensorInfo tensorInfo) {
OnnxTensorDoubleImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, DoubleBuffer::allocate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_SHORT;

import java.lang.foreign.MemorySession;
import java.lang.foreign.ValueLayout;
import java.util.Collections;

/**
* A tensor type from ONNX.
Expand Down Expand Up @@ -44,7 +44,7 @@ public enum OnnxTensorElementDataType {
private OnnxTensorElementDataType(int number, ValueLayout valueLayout) {
this.number = number;
this.valueLayout = valueLayout;
this.scalarInfo = new TensorInfoImpl(this, Collections.singletonList(1L), 1L);
this.scalarInfo = TensorInfoImpl.of(this, 1L, MemorySession.global());
}

TensorInfo getScalarInfo() {
Expand All @@ -57,7 +57,9 @@ public int getNumber() {

/**
* Get a level based off its internal number.
* @param number the internal number of the level
*
* @param number
* the internal number of the level
* @return the level, UNDEFINED if not found
*/
public static final OnnxTensorElementDataType forNumber(int number) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorFloatImpl extends OnnxTensorBufferImpl<FloatBuffer> {

OnnxTensorFloatImpl(TensorInfo tensorInfo) {
OnnxTensorFloatImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, FloatBuffer::allocate);
}

Expand Down
16 changes: 3 additions & 13 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Collection;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.stream.Stream;

abstract class OnnxTensorImpl extends OnnxValueImpl implements OnnxTensor {

protected final TensorInfo tensorInfo;
protected final TensorInfoImpl tensorInfo;

protected OnnxTensorImpl(TensorInfo tensorInfo) {
protected OnnxTensorImpl(TensorInfoImpl tensorInfo) {
super(OnnxType.TENSOR);
this.tensorInfo = tensorInfo;
}
Expand Down Expand Up @@ -77,16 +76,7 @@ public String[] getStringBuffer() {

abstract void getScalars(Stream<OnnxTensorImpl> scalars);

protected static final long[] shape(List<Long> original) {
int shapeSize = original.size();
long[] shapeArray = new long[shapeSize];
for (int i = 0; i < shapeSize; i++) {
shapeArray[i] = original.get(i);
}
return shapeArray;
}

static final OnnxTensorImpl fromTypeInfo(TensorInfo tensorInfo) {
static final OnnxTensorImpl fromTypeInfo(TensorInfoImpl tensorInfo) {
OnnxTensorElementDataType type = tensorInfo.getType();
switch (type) {
case BOOL:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorIntImpl extends OnnxTensorBufferImpl<IntBuffer> {

OnnxTensorIntImpl(TensorInfo tensorInfo) {
OnnxTensorIntImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, IntBuffer::allocate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorLongImpl extends OnnxTensorBufferImpl<LongBuffer> {

OnnxTensorLongImpl(TensorInfo tensorInfo) {
OnnxTensorLongImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, LongBuffer::allocate);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

final class OnnxTensorShortImpl extends OnnxTensorBufferImpl<ShortBuffer> {

OnnxTensorShortImpl(TensorInfo tensorInfo) {
OnnxTensorShortImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo, ShortBuffer::allocate);
}

Expand Down
11 changes: 3 additions & 8 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package com.jyuzawa.onnxruntime;

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;

import java.lang.foreign.MemoryAddress;
Expand All @@ -14,14 +13,13 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Stream;

final class OnnxTensorStringImpl extends OnnxTensorImpl {

private final String[] buffer;

OnnxTensorStringImpl(TensorInfo tensorInfo) {
OnnxTensorStringImpl(TensorInfoImpl tensorInfo) {
super(tensorInfo);
this.buffer = new String[Math.toIntExact(tensorInfo.getElementCount())];
}
Expand All @@ -45,15 +43,12 @@ public MemoryAddress toNative(
for (int i = 0; i < numOutputs; i++) {
stringArray.setAtIndex(C_POINTER, i, allocator.allocateUtf8String(buffer[i]));
}
List<Long> shape = tensorInfo.getShape();
int shapeSize = shape.size();
MemorySegment shapeData = allocator.allocateArray(C_LONG, shape(shape));
MemoryAddress tensor = api.create(
allocator,
out -> api.CreateTensorAsOrtValue.apply(
ortAllocator,
shapeData.address(),
shapeSize,
tensorInfo.shapeData.address(),
tensorInfo.getShape().size(),
tensorInfo.getType().getNumber(),
out));
api.checkStatus(api.FillStringTensor.apply(tensor, stringArray.address(), numOutputs));
Expand Down
Loading

0 comments on commit 40e14ef

Please sign in to comment.