Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cpu and memory optimization #67

Merged
merged 2 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.function.Function;

final class ApiImpl implements Api {
Expand Down Expand Up @@ -306,19 +307,19 @@ void checkStatus(Addressable rawAddress) {
throw new OnnxRuntimeException(code, message);
}

/**
 * Allocates one {@code C_POINTER}-sized segment from {@code allocator}, passes that segment's
 * address to {@code constructor}, runs {@code checkStatus} on the result, and returns the pointer
 * value stored at index 0 of the segment.
 *
 * <p>NOTE(review): the two signature lines below come from a diff view whose +/- markers were
 * lost; they differ only in the allocator type. Presumably the {@code SegmentAllocator} line is
 * the post-change version — confirm against the merged file.
 *
 * @param allocator source of the temporary output slot
 * @param constructor callback that receives the slot address and returns a status to check
 * @return the pointer read back from the slot after the status check
 */
MemoryAddress create(MemorySession allocator, Function<MemoryAddress, Addressable> constructor) {
MemoryAddress create(SegmentAllocator allocator, Function<MemoryAddress, Addressable> constructor) {
MemorySegment pointer = allocator.allocate(C_POINTER);
checkStatus(constructor.apply(pointer.address()));
return pointer.getAtIndex(C_POINTER, 0);
}

/**
 * Allocates one {@code C_INT}-sized segment from {@code allocator}, passes that segment's address
 * to {@code method}, runs {@code checkStatus} on the result, and returns the int stored at index 0
 * of the segment.
 *
 * <p>NOTE(review): the two signature lines below are diff residue (markers lost) and differ only
 * in the allocator type; presumably the {@code SegmentAllocator} line is the post-change version
 * — confirm.
 *
 * @param allocator source of the temporary output slot
 * @param method callback that receives the slot address and returns a status to check
 * @return the int read back from the slot after the status check
 */
int extractInt(MemorySession allocator, Function<MemoryAddress, Addressable> method) {
int extractInt(SegmentAllocator allocator, Function<MemoryAddress, Addressable> method) {
MemorySegment pointer = allocator.allocate(C_INT);
checkStatus(method.apply(pointer.address()));
return pointer.getAtIndex(C_INT, 0);
}

long extractLong(MemorySession allocator, Function<MemoryAddress, Addressable> method) {
long extractLong(SegmentAllocator allocator, Function<MemoryAddress, Addressable> method) {
MemorySegment pointer = allocator.allocate(C_LONG);
checkStatus(method.apply(pointer.address()));
return pointer.getAtIndex(C_LONG, 0);
Expand Down
28 changes: 13 additions & 15 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxMapImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -49,11 +50,11 @@ static final OnnxValueImpl fromTypeInfo(MapInfoImpl mapInfo) {
}
}

/**
 * Builds a key vector by creating a {@code TensorInfoImpl} from the map's key type, {@code size}
 * and {@code scope}, and passing it to {@code keyVectorFactory}.
 *
 * <p>NOTE(review): the two signature lines below are diff residue (markers lost) and differ only
 * in the type of {@code scope}; presumably the {@code SegmentAllocator} line is the post-change
 * version — confirm.
 */
private final T newKeyVector(int size, MemorySession scope) {
private final T newKeyVector(int size, SegmentAllocator scope) {
return keyVectorFactory.apply(TensorInfoImpl.of(mapInfo.getKeyType(), size, scope));
}

/**
 * Builds a value vector via {@code OnnxTensorImpl.fromTypeInfo}, using a {@code TensorInfoImpl}
 * created from the element type of the map's value tensor info, {@code size} and {@code scope}.
 *
 * <p>NOTE(review): the two signature lines below are diff residue (markers lost) and differ only
 * in the type of {@code scope}; presumably the {@code SegmentAllocator} line is the post-change
 * version — confirm.
 */
private final OnnxTensorImpl newValueVector(int size, MemorySession scope) {
private final OnnxTensorImpl newValueVector(int size, SegmentAllocator scope) {
return OnnxTensorImpl.fromTypeInfo(
TensorInfoImpl.of(mapInfo.getValueType().getTensorInfo().getType(), size, scope));
}
Expand Down Expand Up @@ -93,7 +94,7 @@ public OnnxTypedMap<String> asStringMap() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int size = data.size();
T keyVector = newKeyVector(size, allocator);
implodeKeyVector(keyVector, data.keySet());
Expand All @@ -105,29 +106,26 @@ public MemoryAddress toNative(
MemorySegment kvArray = allocator.allocateArray(C_POINTER, 2);
kvArray.setAtIndex(C_POINTER, 0, keyAddress);
kvArray.setAtIndex(C_POINTER, 1, valueAddress);
MemoryAddress value = api.create(
allocator, out -> api.CreateValue.apply(kvArray.address(), 2, OnnxType.MAP.getNumber(), out));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(value);
});
return value;
return api.create(allocator, out -> api.CreateValue.apply(kvArray.address(), 2, OnnxType.MAP.getNumber(), out));
}

/**
 * Populates this map from the native value at {@code address}: fetches elements 0 and 1 of the
 * value, sizes a key vector and a value vector from the key tensor's element count, loads both
 * vectors via their own {@code fromNative}, and then feeds the exploded keys (mapped through
 * {@code set}) to {@code valueVector.getScalars}.
 *
 * <p>NOTE(review): this span is diff residue with the +/- markers lost. It contains two
 * signatures (a one-line form taking a {@code MemorySession} allocator, and a multi-line form
 * taking a {@code SegmentAllocator} plus a {@code MemorySession session}) and two variants of the
 * nested {@code fromNative} calls. Presumably the five-argument forms are the post-change
 * version — confirm against the merged file.
 */
@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
// Element 0 of the native value is read as the key tensor, element 1 as the value tensor.
MemoryAddress keyAddress = api.create(allocator, out -> api.GetValue.apply(address, 0, ortAllocator, out));
MemoryAddress valueAddress = api.create(allocator, out -> api.GetValue.apply(address, 1, ortAllocator, out));
MemoryAddress keyInfo = api.create(allocator, out -> api.GetTensorTypeAndShape.apply(keyAddress, out));
// The vector size is the key tensor's element count; Math.toIntExact throws ArithmeticException
// if that count does not fit in an int.
int size =
Math.toIntExact(api.extractLong(allocator, out -> api.GetTensorShapeElementCount.apply(keyInfo, out)));
T keyVector = newKeyVector(size, allocator);
OnnxTensorImpl valueVector = newValueVector(size, allocator);
// Two variants of the same pair of calls: without and with the trailing session argument.
keyVector.fromNative(api, ortAllocator, keyAddress, allocator);
valueVector.fromNative(api, ortAllocator, valueAddress, allocator);
keyVector.fromNative(api, ortAllocator, keyAddress, allocator, session);
valueVector.fromNative(api, ortAllocator, valueAddress, allocator, session);
// Presumably pairs each key with the scalar at the same position — verify against getScalars.
valueVector.getScalars(explodeKeyVector(keyVector).map(this::set));
// NOTE(review): registers ReleaseValue on close; presumably belongs only to the MemorySession
// variant of this method — confirm where the release happens in the merged code.
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
18 changes: 12 additions & 6 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxOpaqueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.nio.ByteBuffer;

final class OnnxOpaqueImpl extends OnnxValueImpl implements OnnxOpaque {
Expand All @@ -18,10 +19,10 @@ final class OnnxOpaqueImpl extends OnnxValueImpl implements OnnxOpaque {
this.opaqueInfo = opaqueInfo;
}

// @Override
// public OnnxOpaque asOpaque() {
// return this;
// }
// @Override
// public OnnxOpaque asOpaque() {
// return this;
// }

@Override
public OpaqueInfo getInfo() {
Expand All @@ -35,13 +36,18 @@ public String toString() {

/**
 * Not implemented for opaque values: ignores all arguments and returns {@code null}.
 *
 * <p>NOTE(review): the two parameter lines below are diff residue (markers lost) and differ only
 * in the allocator type; presumably the {@code SegmentAllocator} line is the post-change version
 * — confirm.
 */
@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
// TODO Auto-generated method stub
return null;
}

/**
 * Not implemented for opaque values: the body is empty, so nothing is read from {@code address}.
 *
 * <p>NOTE(review): two signatures appear below because this is diff residue (markers lost): a
 * one-line form with a {@code MemorySession} allocator and a multi-line form with a
 * {@code SegmentAllocator} plus a {@code MemorySession session}. Presumably the multi-line form
 * is the post-change version — confirm.
 */
@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
// TODO Auto-generated method stub

}
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxOptionalImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.NoSuchElementException;

final class OnnxOptionalImpl extends OnnxValueImpl implements OnnxOptional {
Expand Down Expand Up @@ -55,13 +56,18 @@ public OnnxValue set() {

/**
 * Not implemented for optional values: ignores all arguments and returns {@code null}.
 *
 * <p>NOTE(review): the two parameter lines below are diff residue (markers lost) and differ only
 * in the allocator type; presumably the {@code SegmentAllocator} line is the post-change version
 * — confirm.
 */
@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
// TODO Auto-generated method stub
return null;
}

/**
 * Not implemented for optional values: the body is empty, so nothing is read from
 * {@code address}.
 *
 * <p>NOTE(review): two signatures appear below because this is diff residue (markers lost): a
 * one-line form with a {@code MemorySession} allocator and a multi-line form with a
 * {@code SegmentAllocator} plus a {@code MemorySession session}. Presumably the multi-line form
 * is the post-change version — confirm.
 */
@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
// TODO Auto-generated method stub

}
Expand Down
21 changes: 10 additions & 11 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxSequenceImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand Down Expand Up @@ -57,36 +58,34 @@ public OnnxValue add() {

/**
 * Converts this sequence to a native value: allocates a {@code C_POINTER} array with one slot per
 * element of {@code data}, fills each slot with that element's own {@code toNative} result, and
 * calls {@code CreateValue} with {@code OnnxType.SEQUENCE} through {@code api.create}.
 *
 * <p>NOTE(review): this span is diff residue with the +/- markers lost. It contains two parameter
 * lines (allocator typed as {@code MemorySession} vs. {@code SegmentAllocator}) and two endings
 * (store the result, register a close action, then return it vs. return {@code api.create(...)}
 * directly). Presumably the {@code SegmentAllocator} / direct-return lines are the post-change
 * version — confirm against the merged file.
 */
@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int size = data.size();
MemorySegment valuesArray = allocator.allocateArray(C_POINTER, size);
// Each element is converted with the same allocator and its address stored at its own index.
for (int i = 0; i < size; i++) {
OnnxValueImpl value = data.get(i);
valuesArray.setAtIndex(C_POINTER, i, value.toNative(api, ortAllocator, memoryInfo, allocator));
}
// Two variants of the same call start here: assign-to-local vs. return-directly.
MemoryAddress value = api.create(
return api.create(
allocator,
out -> api.CreateValue.apply(valuesArray.address(), size, OnnxType.SEQUENCE.getNumber(), out));
// NOTE(review): registers ReleaseValue on close; presumably part of the MemorySession variant
// only — confirm where the release happens in the merged code.
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(value);
});
return value;
}

/**
 * Populates this sequence from the native value at {@code address}: reads the element count via
 * {@code GetValueCount}, then for each index fetches the element with {@code GetValue}, creates a
 * fresh {@code OnnxValueImpl} from {@code typeInfo}, loads it via its own {@code fromNative}, and
 * appends it to {@code data}.
 *
 * <p>NOTE(review): this span is diff residue with the +/- markers lost. It contains two
 * signatures (one-line with a {@code MemorySession} allocator; multi-line with a
 * {@code SegmentAllocator} plus a {@code MemorySession session}) and two variants of the nested
 * {@code fromNative} call. Presumably the five-argument forms are the post-change version —
 * confirm against the merged file.
 */
@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
long outputs = api.extractLong(allocator, out -> api.GetValueCount.apply(address, out));
for (int i = 0; i < outputs; i++) {
// Effectively-final copy of the loop index so the lambda below can capture it.
final int index = i;
MemoryAddress valueAddress =
api.create(allocator, out -> api.GetValue.apply(address, index, ortAllocator, out));
OnnxValueImpl value = OnnxValueImpl.fromTypeInfo(typeInfo);
// Two variants of the same call: without and with the trailing session argument.
value.fromNative(api, ortAllocator, valueAddress, allocator);
value.fromNative(api, ortAllocator, valueAddress, allocator, session);
data.add(value);
}
// NOTE(review): registers ReleaseValue on close; presumably part of the MemorySession variant
// only — confirm where the release happens in the merged code.
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
20 changes: 9 additions & 11 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorBufferImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.nio.Buffer;
import java.util.function.IntFunction;

Expand All @@ -26,13 +27,13 @@ public final String toString() {

@Override
public final MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
MemorySegment rawInputData = getMemorySegment();
// TODO: move value layout to this class?
MemorySegment inputData =
allocator.allocateArray(tensorInfo.getType().getValueLayout(), rawInputData.byteSize());
inputData.copyFrom(rawInputData);
MemoryAddress tensor = api.create(
return api.create(
allocator,
out -> api.CreateTensorWithDataAsOrtValue.apply(
memoryInfo,
Expand All @@ -42,21 +43,18 @@ public final MemoryAddress toNative(
tensorInfo.getShape().size(),
tensorInfo.getType().getNumber(),
out));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(tensor);
});
return tensor;
}

/**
 * Copies tensor data out of the native value at {@code address}: obtains the data pointer via
 * {@code GetTensorMutableData}, wraps it as a {@code MemorySegment} of
 * {@code tensorInfo.getByteCount()} bytes, and copies that segment into
 * {@code getMemorySegment()}.
 *
 * <p>NOTE(review): this span is diff residue with the +/- markers lost. It contains two parameter
 * lists (a single {@code MemorySession allocator} vs. a {@code SegmentAllocator allocator} plus a
 * {@code MemorySession session}) and two {@code MemorySegment.ofAddress} lines that differ only in
 * the last argument. Presumably the {@code session} forms are the post-change version — confirm
 * against the merged file.
 */
@Override
public final void fromNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
MemoryAddress floatOutput = api.create(allocator, out -> api.GetTensorMutableData.apply(address, out));
// Two variants of the same statement: the segment is bound to `allocator` vs. `session`.
MemorySegment segment = MemorySegment.ofAddress(floatOutput, tensorInfo.getByteCount(), allocator);
MemorySegment segment = MemorySegment.ofAddress(floatOutput, tensorInfo.getByteCount(), session);
getMemorySegment().copyFrom(segment);
// NOTE(review): registers ReleaseValue on close; presumably part of the MemorySession-allocator
// variant only — confirm where the release happens in the merged code.
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

protected abstract MemorySegment getMemorySegment();
Expand Down
17 changes: 8 additions & 9 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxTensorStringImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
Expand All @@ -36,8 +37,7 @@ public String[] getStringBuffer() {

@Override
public MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession allocator) {

ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator) {
int numOutputs = buffer.length;
MemorySegment stringArray = allocator.allocateArray(C_POINTER, numOutputs);
for (int i = 0; i < numOutputs; i++) {
Expand All @@ -52,14 +52,16 @@ public MemoryAddress toNative(
tensorInfo.getType().getNumber(),
out));
api.checkStatus(api.FillStringTensor.apply(tensor, stringArray.address(), numOutputs));
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(tensor);
});
return tensor;
}

@Override
public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession allocator) {
public void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session) {
int numOutputs = buffer.length;
for (int i = 0; i < numOutputs; i++) {
final long index = i;
Expand All @@ -70,9 +72,6 @@ public void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress ad
api.checkStatus(api.GetStringTensorElement.apply(address, length, index, output.address()));
buffer[i] = output.getUtf8String(0);
}
allocator.addCloseAction(() -> {
api.ReleaseValue.apply(address);
});
}

@Override
Expand Down
10 changes: 8 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/OnnxValueImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import java.lang.foreign.MemoryAddress;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.NoSuchElementException;

abstract class OnnxValueImpl implements OnnxValue {
Expand Down Expand Up @@ -52,9 +53,14 @@ public OnnxMap asMap() {
// }

/**
 * Abstract hook implemented by each value type to produce a {@code MemoryAddress} for this value.
 *
 * <p>NOTE(review): the two parameter lines below are diff residue (markers lost); they differ in
 * the last parameter ({@code MemorySession scope} vs. {@code SegmentAllocator allocator}).
 * Presumably the {@code SegmentAllocator} line is the post-change version — confirm.
 */
abstract MemoryAddress toNative(
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, MemorySession scope);
ApiImpl api, MemoryAddress ortAllocator, MemoryAddress memoryInfo, SegmentAllocator allocator);

/**
 * Abstract hook implemented by each value type to load this value from the native value at
 * {@code address}.
 *
 * <p>NOTE(review): two declarations appear below because this is diff residue (markers lost): a
 * one-line form ending in {@code MemorySession scope} and a multi-line form taking a
 * {@code SegmentAllocator allocator} plus a {@code MemorySession session}. Presumably the
 * multi-line form is the post-change version — confirm.
 */
abstract void fromNative(ApiImpl api, MemoryAddress ortAllocator, MemoryAddress address, MemorySession scope);
abstract void fromNative(
ApiImpl api,
MemoryAddress ortAllocator,
MemoryAddress address,
SegmentAllocator allocator,
MemorySession session);

static final OnnxValueImpl fromTypeInfo(TypeInfoImpl typeInfo) {
OnnxType type = typeInfo.getType();
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/jyuzawa/onnxruntime/TensorInfoImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.MemorySession;
import java.lang.foreign.SegmentAllocator;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand All @@ -31,7 +31,7 @@ final class TensorInfoImpl implements TensorInfo {
this.elementCount = elementCount;
}

/**
 * Creates a {@code TensorInfoImpl} for {@code type} whose shape data is a one-element
 * {@code C_LONG} array holding {@code elementCount}, allocated from {@code scope}.
 *
 * <p>NOTE(review): the literal {@code 1} passed to the constructor is presumably the number of
 * dimensions — confirm against the constructor. The two signature lines below are diff residue
 * (markers lost) and differ only in the type of {@code scope}; presumably the
 * {@code SegmentAllocator} line is the post-change version — confirm.
 */
static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, MemorySession scope) {
static TensorInfoImpl of(OnnxTensorElementDataType type, long elementCount, SegmentAllocator scope) {
MemorySegment shapeData = scope.allocateArray(C_LONG, new long[] {elementCount});
return new TensorInfoImpl(type, shapeData, 1, elementCount);
}
Expand Down
Loading