Skip to content

Commit

Permalink
Merge pull request #67 from yuzawa-san/perfier
Browse files Browse the repository at this point in the history
cpu and memory optimization
  • Loading branch information
yuzawa-san authored Nov 19, 2022
2 parents 40e14ef + d81b8a5 commit a066012
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 88 deletions.
7 changes: 4 additions & 3 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.function.Function;

final class ApiImpl implements Api {
Expand Down Expand Up @@ -306,19 +307,19 @@ void checkStatus(Addressable rawAddress) {
throw new OnnxRuntimeException(code, message);
}

MemoryAddress create(MemorySession allocator, Function<MemoryAddress, Addressable> constructor) {
MemoryAddress create(SegmentAllocator allocator, Function<MemoryAddress, Addressable> constructor) {
MemorySegment pointer = allocator.allocate(C_POINTER);
checkStatus(constructor.apply(pointer.address()));
return pointer.getAtIndex(C_POINTER, 0);
}

int extractInt(MemorySession allocator, Function<MemoryAddress, Addressable> method) {
int extractInt(SegmentAllocator allocator, Function<MemoryAddress, Addressable> method) {
MemorySegment pointer = allocator.allocate(C_INT);
checkStatus(method.apply(pointer.address()));
return pointer.getAtIndex(C_INT, 0);
}

long extractLong(MemorySession allocator, Function<MemoryAddress, Addressable> method) {
long extractLong(SegmentAllocator allocator, Function<MemoryAddress, Addressable> method) {
MemorySegment pointer = allocator.allocate(C_LONG);
checkStatus(method.apply(pointer.address()));
return pointer.getAtIndex(C_LONG, 0);
Expand Down
28 changes: 13 additions & 15 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -49,11 +50,11 @@ static final OnnxValueImpl fromTypeInfo(MapInfoImpl mapInfo) {
}
}

private final T newKeyVector(int size, MemorySession scope) {
private final T newKeyVector(int size, SegmentAllocator scope) {
return keyVectorFactory.apply(TensorInfoImpl.of(mapInfo.getKeyType(), size, scope));
}

private final OnnxTensorImpl newValueVector(int size, MemorySession scope) {
private final OnnxTensorImpl newValueVector(int size, SegmentAllocator scope) {
return OnnxTensorImpl.fromTypeInfo(
TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, scope));
}
Expand Down Expand Up @@ -93,7 +94,7 @@ public OnnxTypedMap<String> asStringMap() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int size = data.size();
T keyVector = newKeyVector(size, allocator);
implodeKeyVector(keyVector, data.keySet());
Expand All @@ -105,29 +106,26 @@ public MemoryAddress toNative(
MemorySegment kvArray = allocator.allocateArray(C_POINTER, 2);
kvArray.setAtIndex(C_POINTER, 0, keyAddress);
kvArray.setAtIndex(C_POINTER, 1, valueAddress);
MemoryAddress value = api.create(
allocator, out -> api.CreateValue.apply(kvArray.address(), 2, OnnxType.MAP.getNumber(), out));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(value);
});
return value;
return api.create(allocator, out -> api.CreateValue.apply(kvArray.address(), 2, OnnxType.MAP.getNumber(), out));
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
MemoryAddress keyAddress = api.create(allocator, out -> api.GetValue.apply(address, 0, ortAllocator, out));
MemoryAddress valueAddress = api.create(allocator, out -> api.GetValue.apply(address, 1, ortAllocator, out));
MemoryAddress keyInfo = api.create(allocator, out -> api.GetTensorTypeAndShape.apply(keyAddress, out));
int size =
Math.toIntExact(api.extractLong(allocator, out -> api.GetTensorShapeElementCount.apply(keyInfo, out)));
T keyVector = newKeyVector(size, allocator);
OnnxTensorImpl valueVector = newValueVector(size, allocator);
keyVector.fromNative(api, ortAllocator, keyAddress, allocator);
valueVector.fromNative(api, ortAllocator, valueAddress, allocator);
keyVector.fromNative(api, ortAllocator, keyAddress, allocator, session);
valueVector.fromNative(api, ortAllocator, valueAddress, allocator, session);
valueVector.getScalars(explodeKeyVector(keyVector).map(this::set));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
18 changes: 12 additions & 6 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxOpaqueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.nio.ByteBuffer;

final class OnnxOpaqueImpl extends OnnxValueImpl implements OnnxOpaque {
Expand All @@ -18,10 +19,10 @@ final class OnnxOpaqueImpl extends OnnxValueImpl implements OnnxOpaque {
this.opaqueInfo = opaqueInfo;
}

// @Override
// public OnnxOpaque asOpaque() {
// return this;
// }
// @Override
// public OnnxOpaque asOpaque() {
// return this;
// }

@Override
public OpaqueInfo getInfo() {
Expand All @@ -35,13 +36,18 @@ public String toString() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
// TODO Auto-generated method stub
return null;
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
// TODO Auto-generated method stub

}
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.NoSuchElementException;

final class OnnxOptionalImpl extends OnnxValueImpl implements OnnxOptional {
Expand Down Expand Up @@ -55,13 +56,18 @@ public OnnxValue set() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
// TODO Auto-generated method stub
return null;
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
// TODO Auto-generated method stub

}
Expand Down
21 changes: 10 additions & 11 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -57,36 +58,34 @@ public OnnxValue add() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int size = data.size();
MemorySegment valuesArray = allocator.allocateArray(C_POINTER, size);
for (int i = 0; i < size; i++) {
OnnxValueImpl value = data.get(i);
valuesArray.setAtIndex(C_POINTER, i, value.toNative(api, ortAllocator, memoryInfo, allocator));
}
MemoryAddress value = api.create(
return api.create(
allocator,
out -> api.CreateValue.apply(valuesArray.address(), size, OnnxType.SEQUENCE.getNumber(), out));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(value);
});
return value;
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
long outputs = api.extractLong(allocator, out -> api.GetValueCount.apply(address, out));
for (int i = 0; i < outputs; i++) {
final int index = i;
MemoryAddress valueAddress =
api.create(allocator, out -> api.GetValue.apply(address, index, ortAllocator, out));
OnnxValueImpl value = OnnxValueImpl.fromTypeInfo(typeInfo);
value.fromNative(api, ortAllocator, valueAddress, allocator);
value.fromNative(api, ortAllocator, valueAddress, allocator, session);
data.add(value);
}
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
20 changes: 9 additions & 11 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.nio.Buffer;
import java.util.function.IntFunction;

Expand All @@ -26,13 +27,13 @@ public final String toString() {

@Override
public final MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
MemorySegment rawInputData = getMemorySegment();
// TODO: move value layout to this class?
MemorySegment inputData =
allocator.allocateArray(tensorInfo.getType().getValueLayout(), rawInputData.byteSize());
inputData.copyFrom(rawInputData);
MemoryAddress tensor = api.create(
return api.create(
allocator,
out -> api.CreateTensorWithDataAsOrtValue.apply(
memoryInfo,
Expand All @@ -42,21 +43,18 @@ public final MemoryAddress toNative(
tensorInfo.getShape().size(),
tensorInfo.getType().getNumber(),
out));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(tensor);
});
return tensor;
}

@Override
public final void fromNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
MemoryAddress floatOutput = api.create(allocator, out -> api.GetTensorMutableData.apply(address, out));
MemorySegment segment = MemorySegment.ofAddress(floatOutput, tensorInfo.getByteCount(), allocator);
MemorySegment segment = MemorySegment.ofAddress(floatOutput, tensorInfo.getByteCount(), session);
getMemorySegment().copyFrom(segment);
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

protected abstract MemorySegment getMemorySegment();
Expand Down
17 changes: 8 additions & 9 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
Expand All @@ -36,8 +37,7 @@ public String[] getStringBuffer() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {

ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int numOutputs = buffer.length;
MemorySegment stringArray = allocator.allocateArray(C_POINTER, numOutputs);
for (int i = 0; i < numOutputs; i++) {
Expand All @@ -52,14 +52,16 @@ public MemoryAddress toNative(
tensorInfo.getType().getNumber(),
out));
api.checkStatus(api.FillStringTensor.apply(tensor, stringArray.address(), numOutputs));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(tensor);
});
return tensor;
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
int numOutputs = buffer.length;
for (int i = 0; i < numOutputs; i++) {
final long index = i;
Expand All @@ -70,9 +72,6 @@ public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress ad
api.checkStatus(api.GetStringTensorElement.apply(address, length, index, output.address()));
buffer[i] = output.getUtf8String(0);
}
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.NoSuchElementException;

abstract class OnnxValueImpl implements OnnxValue {
Expand Down Expand Up @@ -52,9 +53,14 @@ public OnnxMap asMap() {
// }

abstract MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession scope);
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator);

abstract void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession scope);
abstract void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session);

static final OnnxValueImpl fromTypeInfo(TypeInfoImpl typeInfo) {
OnnxType type = typeInfo.getType();
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -31,7 +31,7 @@ final class TensorInfoImpl implements TensorInfo {
this.elementCount = elementCount;
}

static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, MemorySession scope) {
static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, SegmentAllocator scope) {
MemorySegment shapeData = scope.allocateArray(C_LONG, new long[] {elementCount});
return new TensorInfoImpl(type, shapeData, 1, elementCount);
}
Expand Down
Loading

0 comments on commit a066012

Please sign in to comment.